Updated documentation and examples for BTF
orso82 committed Mar 13, 2019
1 parent cf27402 commit 3b6029e
Showing 5 changed files with 157 additions and 109 deletions.
11 changes: 5 additions & 6 deletions brainfusetf.c
@@ -5,12 +5,11 @@ int main(int argc, char**argv){

int i;
char btf_sendline[65507];
char model[256]="tglfnn/models/nn_SAT1_mb_1024_abs_reg_common_stair2x2x6.pb";
double input[21]={5.45338e-01, 4.17925e-02, 7.21986e-03, 1.24956e-01, -1.37899e-01, 1.58491e+00, -4.20101e-03, 1.55640e+00, 8.36932e+00, 1.02569e+00, 2.05921e+00, -4.45231e-01, 3.00670e+00, 2.06023e+00, 2.38220e+00, 7.66336e-01, 3.20824e-01, 1.14110e+00, 3.21049e-01, 3.36619e-01, 1.87324e+00};
double output[6];
char model[256]="eped1nn/models/EPED_mb_128_pow_norm_common_30x10.pb";
double input[10]={0.5778, 1.8034, 2.0995, 0.2075, 1.1621, 1.8017, 2, 4.0101, 1.6984, 1.4429};
double output[18];

btf_run(model,input,sizeof(input)/sizeof(double),output,sizeof(output)/sizeof(double));
for (i=0; i<6; i++)
btf_run(model, input, sizeof(input)/sizeof(double), output, sizeof(output)/sizeof(double));
for (i=0; i<sizeof(output)/sizeof(output[0]); i++)
printf("%f\n",output[i]);
}

172 changes: 86 additions & 86 deletions brainfusetf.py
@@ -8,25 +8,25 @@
from numpy import *
import struct

default_serve_port=8883
default_serve_port = 8883

#=======================
# =======================
# helper functions
#=======================
# =======================
def send_data(sock, path, data):
payload='{path}&{shape}&[{data}]'.format(path=path,
shape=data.shape,
data=','.join(map(lambda x:'%f'%x,data.flatten())))
payload = '{path}&{shape}&[{data}]'.format(path=path,
shape=data.shape,
data=','.join(map(lambda x: '%f' % x, data.flatten())))
return send_msg(sock, payload)

def send_ask_info(sock, path):
payload='{path}&(?,?)'.format(path=path)
payload = '{path}&(?,?)'.format(path=path)
return send_msg(sock, payload)

def send_info(sock, path, x_names, y_names):
payload='{path}&{x_names}&{y_names}'.format(path=path,
x_names=repr(map(str,x_names)),
y_names=repr(map(str,y_names)))
payload = '{path}&{x_names}&{y_names}'.format(path=path,
x_names=repr(map(str, x_names)),
y_names=repr(map(str, y_names)))
return send_msg(sock, payload)

def send_msg(sock, msg):
@@ -36,14 +36,13 @@ def send_msg(sock, msg):

def parse_data(data):
try:
path,shape,data=data.split('&')
path, shape, data = data.split('&')
except Exception:
print(data)
raise
return path, np.reshape(eval(data),eval(shape))
raise(Exception('Malformed input data'))
return path, np.reshape(eval(data), eval(shape))

def parse_info(data):
path,x_names,y_names=data.split('&')
path, x_names, y_names = data.split('&')
return path, eval(x_names), eval(y_names)

def recv_msg(sock):
@@ -65,15 +64,15 @@ def recvall(sock, n):
data += packet
return data

#=======================
# =======================
# client
#=======================
# =======================
class btf_connect(object):
def __init__(self, path, host=None, port=None):
if host is None:
host=os.environ.get('BTF_HOST','gadb-harvest.ddns.net')
host = os.environ.get('BTF_HOST', 'gadb-harvest.ddns.net')
if port is None:
port=int(os.environ.get('BTF_PORT',default_serve_port))
port = int(os.environ.get('BTF_PORT', default_serve_port))
self.host = host
self.port = port
self.path = path
@@ -84,7 +83,7 @@ def __enter__(self):
try:
self.sock.connect((self.host, self.port))
except:
print('HOST:%s PORT:%s'%(self.host,self.port))
print('HOST:%s PORT:%s' % (self.host, self.port))
return self

def __exit__(self, exec_type, exec_value, exec_tb):
@@ -99,17 +98,17 @@ def info(self):
return self.x_names, self.y_names

def run(self, input):
if isinstance(input,dict):
if not hasattr(self,'x_names') or not hasattr(self,'y_names'):
with btf_connect(host=self.host,port=self.port,path=self.path) as btf:
self.x_names,self.y_names=btf.info()
if isinstance(input, dict):
if not hasattr(self, 'x_names') or not hasattr(self, 'y_names'):
with btf_connect(host=self.host, port=self.port, path=self.path) as btf:
self.x_names, self.y_names = btf.info()
print(self.x_names)
print(self.y_names)
input=np.array([input[name] for name in self.x_names]).T
input = np.array([input[name] for name in self.x_names]).T
send_data(self.sock, self.path, input)
path,output=parse_data(recv_msg(self.sock))
if isinstance(input,dict):
output={name:output[:,k] for k,name in enumerate(self.y_names)}
path, output = parse_data(recv_msg(self.sock))
if isinstance(input, dict):
output = {name: output[:, k] for k, name in enumerate(self.y_names)}
return output

def activateNets(nets, dB):
@@ -120,38 +119,39 @@ def activateNets(nets, dB):
:return: tuple with (out,sut,targets,nets,out_)
'''
if isinstance(nets,basestring):
nets={k:OMFITpath(file) for k,file in enumerate(glob.glob(nets))}
elif not isinstance(nets,(list,tuple)):
nets={0:nets}
net=nets.values()[0]
if isinstance(nets, basestring):
nets = {k: OMFITpath(file) for k, file in enumerate(glob.glob(nets))}
elif not isinstance(nets, (list, tuple)):
nets = {0: nets}
net = nets.values()[0]

with btf_connect(path=net.filename) as btf:
inputNames,outputNames=btf.info()
targets=array([dB[k] for k in outputNames]).T
inputNames, outputNames = btf.info()
targets = array([dB[k] for k in outputNames]).T

out_=empty((len(atleast_1d(dB.values()[0])),len(outputNames),len(nets)))
for k,n in enumerate(nets):
out_ = empty((len(atleast_1d(dB.values()[0])), len(outputNames), len(nets)))
for k, n in enumerate(nets):
with btf_connect(path=net.filename) as btf:
out_[:,:,k]=btf.run(dB)
out=mean(out_,-1)
sut=std(out_,-1)
return out,sut,targets,nets,out_
out_[:, :, k] = btf.run(dB)
out = mean(out_, -1)
sut = std(out_, -1)
return out, sut, targets, nets, out_

#=======================
# =======================
# server
#=======================
# =======================
if __name__ == "__main__":
import tensorflow as tf
from tensorflow.python.platform import gfile

__file__=os.path.abspath(__file__)
__file__ = os.path.abspath(__file__)

serve_port=default_serve_port
if len(sys.argv)>1:
serve_port=int(sys.argv[1])
serve_port = default_serve_port
if len(sys.argv) > 1:
serve_port = int(sys.argv[1])

models = {}

models={}
def activate(path, input):
'''
high level function that handles models buffering
@@ -160,21 +160,21 @@ def activate(path, input):
:return: output array or xnames,ynames
'''
if not path.startswith(os.sep):
path=os.path.split(__file__)[0]+os.sep+path
path=os.path.realpath(path)
path = os.path.split(__file__)[0] + os.sep + path
path = os.path.realpath(path)
if path not in models:
print('Loading %s'%path)
models[path]=model(path=path)
print('Loading %s' % path)
models[path] = model(path=path)
else:
print('mtime %f'%os.path.getmtime(path))
print('ltime %f'%models[path].load_time)
if os.path.getmtime(path)>models[path].load_time:
raise(Exception('Model reload functionality not implemented')) #somehow the model does not get updated?
#print('Updating %s'%path)
#models[path].close()
#models[path]=model(path=path)
print('mtime %f' % os.path.getmtime(path))
print('ltime %f' % models[path].load_time)
if os.path.getmtime(path) > models[path].load_time:
raise (Exception('Model reload functionality not implemented')) # somehow the model does not get updated?
# print('Updating %s'%path)
# models[path].close()
# models[path]=model(path=path)
if input is None:
return models[path].x_names,models[path].y_names
return models[path].x_names, models[path].y_names
else:
return models[path].activate(input)

@@ -183,53 +183,53 @@ def __init__(self, target='', graph=None, config=None, path=None):
tf.Session.__init__(self, target=target, graph=graph, config=config)
self.__enter__()
if not path.startswith(os.sep):
path=os.path.split(__file__)[0]+os.sep+path
path=os.path.realpath(path)
self.path=path
self.load_time=os.path.getmtime(self.path)
path = os.path.split(__file__)[0] + os.sep + path
path = os.path.realpath(path)
self.path = path
self.load_time = os.path.getmtime(self.path)
with gfile.FastGFile(self.path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

self.y, self.x_names, self.y_names = tf.import_graph_def(graph_def, return_elements=['unnormalized_nn/y:0',
'x_names:0',
'y_names:0'], name='')
self.x_names=self.x_names.eval()
self.y_names=self.y_names.eval()
self.x_names = self.x_names.eval()
self.y_names = self.y_names.eval()

def activate(self,input):
def activate(self, input):
'''
:param input: input array
:return: output array
'''
print('Serving %s'%self.path)
print('Serving %s' % self.path)
return self.run(self.y, feed_dict={'x:0': input})

class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):

def handle(self):
msg=recv_msg(self.request)
msg = recv_msg(self.request)
if msg is None:
print("{}: {}".format(self.client_address[0],'-- no message --'))
print("{}: {}".format(self.client_address[0], '-- no message --'))
return
print("{}: message length {}".format(self.client_address[0],len(msg)))
query=msg.split('&')[1]
#respond to info request
if query=='(?,?)':
print("{}: message length {}".format(self.client_address[0], len(msg)))
query = msg.split('&')[1]
# respond to info request
if query == '(?,?)':
print('INFO-MODE')
path=msg.split('&')[0]
x_names,y_names=activate(path=path,input=None)
send_info(self.request,path,x_names,y_names)
#respond to data request
path = msg.split('&')[0]
x_names, y_names = activate(path=path, input=None)
send_info(self.request, path, x_names, y_names)
# respond to data request
else:
print('DATA-MODE: serve starts')
while True:
try:
if msg is not None:
path,input=parse_data(msg)
output=activate(path=path,input=input)
send_data(self.request,path,output)
msg=recv_msg(self.request)
path, input = parse_data(msg)
output = activate(path=path, input=input)
send_data(self.request, path, output)
msg = recv_msg(self.request)
except Exception as _excp:
if 'Broken pipe' in repr(_excp) or 'Connection reset by peer' in repr(_excp) or 'Protocol wrong type for socket' in repr(_excp):
print('DATA-MODE: serve ends')
@@ -242,11 +242,11 @@ class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):

server = ThreadedTCPServer(('0.0.0.0', serve_port), ThreadedTCPRequestHandler)
print('-- BASH shell --')
print("export BTF_HOST=%s"%socket.gethostname())
print("export BTF_PORT=%d"%serve_port)
print("export BTF_HOST=%s" % socket.gethostname())
print("export BTF_PORT=%d" % serve_port)
print('-- TCSH shell --')
print("setenv BTF_HOST %s"%socket.gethostname())
print("setenv BTF_PORT %d"%serve_port)
print("setenv BTF_HOST %s" % socket.gethostname())
print("setenv BTF_PORT %d" % serve_port)
try:
server.serve_forever()
except (KeyboardInterrupt, SystemExit):
1 change: 0 additions & 1 deletion brainfusetf_lib.c
@@ -129,7 +129,6 @@ int btf_run(char *model, double *input, int input_len, double *output, int outpu
sprintf(message1,"%s%g,",message1,*(input+i));
}
sprintf(message1,"%s%g]",message1,*(input+input_len-1));

//send request
for(i = 0; i < 10; i++){
ack=0;
36 changes: 31 additions & 5 deletions readme.md
@@ -5,13 +5,26 @@ This repository contains NN models for EPED1-NN, TGLF-NN and NEOjbs-NN

Refer to the `readme.md` files in the subfolders for more info.

Online widget
-------------
Two libraries have been used to train the NN models.

An online widget for exploring the models in this repository is available at http://gadb-harvest.ddns.net
1. the FANN library is an NN library that allows little flexibility in the
definition of the models, but has the advantage of being very portable
and of having bindings to many languages

Installation
------------
2. the TENSORFLOW library allows great flexibility in building NN models
and gives access to the most modern machine learning techniques and algorithms,
but has the disadvantage that deploying the trained models outside of Python
is more difficult. Yet running TENSORFLOW models from C, MATLAB, or FORTRAN
should still be possible. For this purpose a client-server library capable of
serving these models across the web over TCP/IP was developed (see the client
sketch below). The client side of the library is written in pure C (with
FORTRAN and Python interfaces available) so that no external library dependency
is required. This approach lifts the cumbersome installation requirements for
doing inference with TENSORFLOW models from C.

We refer to the set of tools in this repository that are used to run these models as `BRAINFUSE`.
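
For orientation, here is a minimal client sketch built around the `btf_connect` class defined in `brainfusetf.py`. It assumes that `brainfusetf.py` is importable (e.g. the repository directory is on the Python path) and that a BTF server is reachable; by default the client targets the public host `gadb-harvest.ddns.net`, or whatever `BTF_HOST`/`BTF_PORT` point to. The model path and sample inputs are copied from the C example in `brainfusetf.c`.

```python
import numpy as np
from brainfusetf import btf_connect

# Model path and sample inputs taken from the C example in brainfusetf.c
model = 'eped1nn/models/EPED_mb_128_pow_norm_common_30x10.pb'
x = np.array([[0.5778, 1.8034, 2.0995, 0.2075, 1.1621,
               1.8017, 2.0, 4.0101, 1.6984, 1.4429]])

# info() and run() are issued on separate connections,
# mirroring the activateNets() usage in brainfusetf.py
with btf_connect(path=model) as btf:
    x_names, y_names = btf.info()   # input/output names stored in the .pb graph

with btf_connect(path=model) as btf:
    y = btf.run(x)                  # 2D array: one row per sample, one column per output

for name, value in zip(y_names, y[0]):
    print(name, value)
```

A dictionary keyed by the model input names can also be passed to `run()`, in which case the result is returned as a dictionary keyed by the output names.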

FANN models
-----------

Install the FANN c library:

@@ -30,3 +43,16 @@ Set in your .login file:

Run `./compile.sh` script


TENSORFLOW models
-----------------

To get started, one can use the public `brainfusetf` server at `gadb-harvest.ddns.net`
to serve trained models.

A Python example can be run with: `tf_client_server_test.py`

A C example can be run with:

make brainfusetf_run.exe
brainfusetf_run.exe
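
It is also possible to run a private server: when executed directly, `brainfusetf.py` acts as the server and accepts an optional port number as its first argument (the default is 8883); on startup it prints the matching `BTF_HOST`/`BTF_PORT` settings. A minimal sketch of pointing the Python client at such a server follows, assuming one is already running locally; the `localhost` value is purely illustrative.

```python
import os

# Direct the Python client at a private server started with, e.g.:
#   python brainfusetf.py 8883
# (the server prints these same settings on startup)
os.environ['BTF_HOST'] = 'localhost'   # illustrative host name
os.environ['BTF_PORT'] = '8883'        # default_serve_port in brainfusetf.py

from brainfusetf import btf_connect

with btf_connect(path='eped1nn/models/EPED_mb_128_pow_norm_common_30x10.pb') as btf:
    print(btf.info())   # returns (x_names, y_names) for the requested model
```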