TFoS Support #63

Status: Open. Wants to merge 32 commits into base: master.

Commits (32):
- bf6f994: 1. Support NLP non-distribued training (allwefantasy, Oct 11, 2017)
- 3c3fd2d: set test_mode to True which can avoid to kafka dependency (allwefantasy, Oct 13, 2017)
- e0cdad2: clean some file (allwefantasy, Oct 13, 2017)
- afbd95c: Add TFoS support (allwefantasy, Oct 14, 2017)
- a1c1fa0: Fix TFoSTest (allwefantasy, Oct 14, 2017)
- 06cad2e: 1. Support NLP non-distribued training (allwefantasy, Oct 11, 2017)
- 931d603: set test_mode to True which can avoid to kafka dependency (allwefantasy, Oct 13, 2017)
- d3c8a0c: clean some file (allwefantasy, Oct 13, 2017)
- 6566c83: Add TFoS support (allwefantasy, Oct 14, 2017)
- e99862b: Fix TFoSTest (allwefantasy, Oct 14, 2017)
- ce68c29: add TFoS test (allwefantasy, Oct 15, 2017)
- 02424dc: fix conflict (allwefantasy, Oct 15, 2017)
- 196128a: example (allwefantasy, Oct 18, 2017)
- 4e8b11e: move tensorflow map_fun to tf_text_test.py and modify the signature t… (allwefantasy, Oct 18, 2017)
- 15a0c40: 1. Support NLP non-distribued training (allwefantasy, Oct 11, 2017)
- 08e61f3: set test_mode to True which can avoid to kafka dependency (allwefantasy, Oct 13, 2017)
- e51c508: clean some file (allwefantasy, Oct 13, 2017)
- 65a4694: [#55] fix TFImageTransformer example in docs (#58) (phi-dbq, Oct 14, 2017)
- b812764: move tensorflow map_fun to tf_text_test.py and modify the signature t… (allwefantasy, Oct 18, 2017)
- e277b24: fix code style in TFTextTransformer (allwefantasy, Oct 18, 2017)
- edd359c: make sure TFTextTransformer will pass the ./python/run-tests.sh (allwefantasy, Oct 18, 2017)
- 6dc76e2: fix conflict (allwefantasy, Oct 18, 2017)
- b2550c3: fix conflict (allwefantasy, Oct 18, 2017)
- 190a4bd: merge nlp-support (allwefantasy, Oct 18, 2017)
- 6012b35: add tensorflowonspark to python/requirements.txt (allwefantasy, Oct 18, 2017)
- 0039a5a: fix pickle import for python 2/3 (allwefantasy, Oct 18, 2017)
- 4574d91: rm /tmp/mock-kafka/ before run test (allwefantasy, Oct 18, 2017)
- 4e2202a: kafka file conflics (allwefantasy, Oct 18, 2017)
- 43c4cf7: fix (allwefantasy, Oct 18, 2017)
- 713946b: fix pickle in python 3 (allwefantasy, Oct 18, 2017)
- 63e3265: changing TFoSTest.py to TFoSExample.py to avoid unit test (allwefantasy, Oct 18, 2017)
- bbfcb20: move TFoSExample from tests (allwefantasy, Oct 18, 2017)
1 change: 1 addition & 0 deletions .gitignore
```diff
@@ -11,6 +11,7 @@ README.org
 .cache/
 .history/
 .lib/
+.coverage
 dist/*
 target/
 lib_managed/
```
12 changes: 6 additions & 6 deletions README.md
````diff
@@ -131,16 +131,16 @@ Spark DataFrames are a natural construct for applying deep learning models to a

 ```python
 from sparkdl import readImages, TFImageTransformer
+import sparkdl.graph.utils as tfx
 from sparkdl.transformers import utils
 import tensorflow as tf

-g = tf.Graph()
-with g.as_default():
+graph = tf.Graph()
+with tf.Session(graph=graph) as sess:
     image_arr = utils.imageInputPlaceholder()
     resized_images = tf.image.resize_images(image_arr, (299, 299))
     # the following step is not necessary for this graph, but can be for graphs with variables, etc
-    frozen_graph = utils.stripAndFreezeGraph(g.as_graph_def(add_shapes=True), tf.Session(graph=g),
-                                             [resized_images])
+    frozen_graph = tfx.strip_and_freeze_until([resized_images], graph, sess,
+                                              return_graph=True)

 transformer = TFImageTransformer(inputCol="image", outputCol="predictions", graph=frozen_graph,
                                  inputTensor=image_arr, outputTensor=resized_images,
@@ -241,7 +241,7 @@ registerKerasImageUDF("my_keras_inception_udf", InceptionV3(weights="imagenet"),

 ```

+### Estimator

 ## Releases:
 * 0.1.0 initial release
````
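For context, the `frozen_graph` built in the updated snippet is meant to be applied to a DataFrame of images. A minimal sketch follows; the image directory path is hypothetical, and `readImages` is the loader that sparkdl exports.

```python
from sparkdl import readImages

# hypothetical directory of images; readImages loads them into a DataFrame
# with an "image" column that TFImageTransformer consumes
image_df = readImages("/data/sample-images")
predictions_df = transformer.transform(image_df)
predictions_df.select("predictions").show()
```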

175 changes: 175 additions & 0 deletions python/TFoSExample.py
```python
from pyspark.sql import SparkSession
from sparkdl.estimators.tf_text_file_estimator import TFTextFileEstimator, KafkaMockServer
from sparkdl.transformers.tf_text import TFTextTransformer


def map_fun(args={}, ctx=None, _read_data=None):
    from tensorflowonspark import TFNode
    from datetime import datetime
    import math
    import numpy
    import tensorflow as tf
    import time

    print(args)

    EMBEDDING_SIZE = args["embedding_size"]
    feature = args['feature']
    label = args['label']
    params = args['params']['fitParam'][0]
    SEQUENCE_LENGTH = 64

    clusterMode = False if ctx is None else True

    # in cluster mode, stagger parameter-server startup, then launch the TF server
    if clusterMode and ctx.job_name == "ps":
        time.sleep((ctx.worker_num + 1) * 5)

    if clusterMode:
        cluster, server = TFNode.start_cluster_server(ctx, 1)

    def feed_dict(batch):
        # collect the sentence_matrix arrays from a batch of rows
        features = []
        for i in batch:
            features.append(i['sentence_matrix'])

        # print("{} {}".format(feature, features))
        return features

    def build_graph():
        encoder_variables_dict = {
            "encoder_w1": tf.Variable(
                tf.random_normal([SEQUENCE_LENGTH * EMBEDDING_SIZE, 256]), name="encoder_w1"),
            "encoder_b1": tf.Variable(tf.random_normal([256]), name="encoder_b1"),
            "encoder_w2": tf.Variable(tf.random_normal([256, 128]), name="encoder_w2"),
            "encoder_b2": tf.Variable(tf.random_normal([128]), name="encoder_b2")
        }

        def encoder(x, name="encoder"):
            with tf.name_scope(name):
                encoder_w1 = encoder_variables_dict["encoder_w1"]
                encoder_b1 = encoder_variables_dict["encoder_b1"]

                layer_1 = tf.nn.sigmoid(tf.matmul(x, encoder_w1) + encoder_b1)

                encoder_w2 = encoder_variables_dict["encoder_w2"]
                encoder_b2 = encoder_variables_dict["encoder_b2"]

                layer_2 = tf.nn.sigmoid(tf.matmul(layer_1, encoder_w2) + encoder_b2)
                return layer_2

        def decoder(x, name="decoder"):
            with tf.name_scope(name):
                decoder_w1 = tf.Variable(tf.random_normal([128, 256]))
                decoder_b1 = tf.Variable(tf.random_normal([256]))

                layer_1 = tf.nn.sigmoid(tf.matmul(x, decoder_w1) + decoder_b1)

                decoder_w2 = tf.Variable(
                    tf.random_normal([256, SEQUENCE_LENGTH * EMBEDDING_SIZE]))
                decoder_b2 = tf.Variable(
                    tf.random_normal([SEQUENCE_LENGTH * EMBEDDING_SIZE]))

                layer_2 = tf.nn.sigmoid(tf.matmul(layer_1, decoder_w2) + decoder_b2)
                return layer_2

        tf.reset_default_graph()

        input_x = tf.placeholder(tf.float32, [None, SEQUENCE_LENGTH, EMBEDDING_SIZE], name="input_x")
        flattened = tf.reshape(input_x,
                               [-1, SEQUENCE_LENGTH * EMBEDDING_SIZE])

        encoder_op = encoder(flattened)

        tf.add_to_collection('encoder_op', encoder_op)

        y_pred = decoder(encoder_op)

        y_true = flattened

        # loss: 1 - cosine similarity between the reconstruction and the input
        with tf.name_scope("xent"):
            cosine = tf.div(tf.reduce_sum(tf.multiply(y_pred, y_true), 1),
                            tf.multiply(tf.sqrt(tf.reduce_sum(tf.multiply(y_pred, y_pred), 1)),
                                        tf.sqrt(tf.reduce_sum(tf.multiply(y_true, y_true), 1))))
            xent = tf.reduce_sum(tf.subtract(tf.constant(1.0), cosine))
            tf.summary.scalar("xent", xent)

        with tf.name_scope("train"):
            # train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(xent)
            train_step = tf.train.RMSPropOptimizer(0.01).minimize(xent)

        summ = tf.summary.merge_all()
        global_step = tf.Variable(0)
        init_op = tf.global_variables_initializer()
        return input_x, init_op, train_step, xent, global_step, summ

    def train_with_cluster(input_x, init_op, train_step, xent, global_step, summ):
        logdir = TFNode.hdfs_path(ctx, params['model']) if clusterMode else None
        sv = tf.train.Supervisor(is_chief=ctx.task_index == 0,
                                 logdir=logdir,
                                 init_op=init_op,
                                 summary_op=None,
                                 saver=None,
                                 global_step=global_step,
                                 stop_grace_secs=300,
                                 save_model_secs=10)
        with sv.managed_session(server.target) as sess:
            tf_feed = TFNode.DataFeed(ctx.mgr, True)
            step = 0

            while not sv.should_stop() and not tf_feed.should_stop() and step < 100:
                data = tf_feed.next_batch(params["batch_size"])
                batch_data = feed_dict(data)
                step += 1
                _, x, g = sess.run([train_step, xent, global_step], feed_dict={input_x: batch_data})
                print("global_step:{} xent:{}".format(g, x))

            if sv.should_stop() or step >= 100:
                tf_feed.terminate()
        sv.stop()

    def train(input_x, init_op, train_step, xent, global_step, summ):
        with tf.Session() as sess:
            sess.run(init_op)
            ## for i in range(echo)
            for data in _read_data(max_records=params["batch_size"]):
                batch_data = feed_dict(data)
                _, x, g = sess.run([train_step, xent, global_step], feed_dict={input_x: batch_data})
                print("global_step:{} xent:{}".format(g, x))

    if clusterMode and ctx.job_name == "ps":
        server.join()
    elif clusterMode and ctx.job_name == "worker":
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % ctx.task_index,
                cluster=cluster)):
            input_x, init_op, train_step, xent, global_step, summ = build_graph()
        train_with_cluster(input_x, init_op, train_step, xent, global_step, summ)
    else:
        input_x, init_op, train_step, xent, global_step, summ = build_graph()
        train(input_x, init_op, train_step, xent, global_step, summ)


input_col = "text"
output_col = "sentence_matrix"

session = SparkSession.builder.master("spark://allwefantasy:7077").appName("test").getOrCreate()
documentDF = session.createDataFrame([
    ("Hi I heard about Spark", 1),
    ("I wish Java could use case classes", 0),
    ("Logistic regression models are neat", 2)
], ["text", "preds"])

# transform the text column into a sentence_matrix column containing a 2-D array
transformer = TFTextTransformer(
    inputCol=input_col, outputCol=output_col, embeddingSize=100, sequenceLength=64)

df = transformer.transform(documentDF)

# create an estimator for training; map_fun contains the TensorFlow code
estimator = TFTextFileEstimator(inputCol="sentence_matrix", outputCol="sentence_matrix", labelCol="preds",
                                fitParam=[{"epochs": 1, "cluster_size": 2, "batch_size": 1, "model": "/tmp/model"}],
                                runningMode="TFoS",
                                mapFnParam=map_fun)
estimator.fit(df).collect()
```
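The example above drives training through TensorFlowOnSpark (`runningMode="TFoS"`). For a single-machine sanity check, a minimal sketch of a local-mode run follows; it assumes the estimator's default `runningMode` invokes `map_fun` with `ctx=None`, which is the `_read_data` code path in `map_fun` above.

```python
# hypothetical local-mode variant: without runningMode="TFoS", map_fun is
# assumed to run with ctx=None, training via the _read_data branch
local_estimator = TFTextFileEstimator(inputCol="sentence_matrix", outputCol="sentence_matrix",
                                      labelCol="preds",
                                      fitParam=[{"epochs": 1, "batch_size": 1, "model": "/tmp/model"}],
                                      mapFnParam=map_fun)
local_estimator.fit(df).collect()
```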
5 changes: 5 additions & 0 deletions python/requirements.txt
```diff
@@ -9,3 +9,8 @@ pygments>=2.2.0
 tensorflow==1.3.0
 pandas>=0.19.1
 six>=1.10.0
+kafka-python>=1.3.5
+tensorflowonspark>=1.0.5
+tensorflow-tensorboard>=0.1.6
+
+
```
4 changes: 2 additions & 2 deletions python/sparkdl/__init__.py
```diff
@@ -21,7 +21,7 @@

 __all__ = [
     'imageSchema', 'imageType', 'readImages',
-    'TFImageTransformer',
+    'TFImageTransformer', 'TFTextTransformer',
     'DeepImagePredictor', 'DeepImageFeaturizer',
-    'KerasImageFileTransformer',
+    'KerasImageFileTransformer', 'TFTextFileEstimator',
     'imageInputPlaceholder']
```
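Assuming the package `__init__` imports these classes in addition to listing them in `__all__`, the new NLP components become importable from the top-level namespace:

```python
# sketch: top-level imports expected to work with this change
from sparkdl import TFTextTransformer, TFTextFileEstimator
```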