-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #236 from isi-nlp/130-restapi
Add REST API around decoder
- Loading branch information
Showing
7 changed files
with
168 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
ARG experiment_dir | ||
FROM python:3.7 | ||
# RUN echo ${experiment_dir} | ||
# RUN echo "hello world" | ||
# COPY ${experiment_dir} /experiment/ | ||
COPY ./experiments/sample-exp/ /experiment/ | ||
COPY . /rtg | ||
WORKDIR /rtg | ||
RUN pip install -e ./ | ||
RUN python -m rtg.pipeline /experiment/ | ||
CMD python -m rtg.deploy /experiment/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,5 +17,7 @@ include::clitools.adoc[] | |
|
||
include::environ.adoc[] | ||
|
||
include::serve.adoc[] | ||
|
||
include::develop.adoc[] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
|
||
== RTG Serve | ||
|
||
RTG model can be served using Flask Server. | ||
|
||
|
||
[source,commandline] | ||
---- | ||
$ python -m rtg.serve -h # rtg-serve | ||
[07-13 22:38:01] p49095 {__init__:53} INFO - rtg v0.3.1 from /Users/tg/work/me/rtg | ||
usage: rtg.serve [-h] [-sc] [-p PORT] [-ho HOST] [-msl MAX_SRC_LEN] exp_dir | ||
deploy a model to a RESTful react server | ||
positional arguments: | ||
exp_dir Experiment directory | ||
optional arguments: | ||
-h, --help show this help message and exit | ||
-sc, --skip-check Skip Checking whether the experiment dir is prepared | ||
and trained (default: False) | ||
-p PORT, --port PORT port to run server on (default: 6060) | ||
-ho HOST, --host HOST | ||
Host address to bind. (default: 0.0.0.0) | ||
-msl MAX_SRC_LEN, --max-src-len MAX_SRC_LEN | ||
max source len; longer seqs will be truncated | ||
(default: None) | ||
---- | ||
|
||
|
||
To launch a service for `runs/001-tfm` experiment | ||
|
||
`python -m rtg.serve -sc runs/001-tfm` | ||
|
||
It prints : | ||
`* Running on http://0.0.0.0:6060/ (Press CTRL+C to quit)` | ||
|
||
Currently only `/translate` API is supported. It accepts both `GET` with query params and `POST` with form params. | ||
|
||
NOTE: batch decoding is yet to be supported. Current decoder decodes only one sentence at a time. | ||
|
||
An example POST request: | ||
---- | ||
curl --data "source=Comment allez-vous?" --data "source=Bonne journée" http://localhost:6060/translate | ||
---- | ||
[source,json] | ||
---- | ||
{ | ||
"source": [ | ||
"Comment allez-vous?", | ||
"Bonne journée" | ||
], | ||
"translation": [ | ||
"How are you?", | ||
"Have a nice day" | ||
] | ||
} | ||
---- | ||
You can also request like GET method as `http://localhost:6060/translate?source=text1&source=text2` | ||
after properly URL encoding the `text1` `text2`. This should only be used for quick testing in your web browser. | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
#!/usr/bin/env python | ||
""" | ||
Serves an RTG model using Flask HTTP server | ||
""" | ||
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter | ||
|
||
from flask import Flask, request, jsonify | ||
import torch | ||
|
||
from rtg import TranslationExperiment as Experiment, log | ||
from rtg.module.decoder import Decoder | ||
|
||
|
||
def prepare_decoder(cli_args): | ||
# No grads required for decode | ||
torch.set_grad_enabled(False) | ||
exp = Experiment(cli_args.pop("exp_dir"), read_only=True) | ||
dec_args = exp.config.get("decoder") or exp.config["tester"].get("decoder", {}) | ||
validate_args(cli_args, dec_args, exp) | ||
decoder = Decoder.new(exp, ensemble=dec_args.pop("ensemble", 1)) | ||
return decoder, dec_args | ||
|
||
|
||
def attach_translate_route(app, decoder, dec_args): | ||
|
||
app.config['JSON_AS_ASCII'] = False | ||
|
||
@app.route("/translate", methods=["POST", "GET"]) | ||
def translate(): | ||
if request.method not in ("POST", "GET"): | ||
return "GET and POST are supported", 400 | ||
if request.method == 'GET': | ||
sources = request.args.getlist("source", None) | ||
else: | ||
sources = request.form.getlist("source", None) | ||
if not sources: | ||
return "Please provide parameter 'source'", 400 | ||
|
||
translations = [] | ||
for source in sources: | ||
translated = decoder.decode_sentence(source, **dec_args)[0][1] | ||
translations.append(translated) | ||
res = dict(source=sources, translation=translations) | ||
return jsonify(res) | ||
|
||
|
||
def validate_args(cli_args, conf_args, exp: Experiment): | ||
if not cli_args.pop("skip_check"): # if --skip-check is not requested | ||
assert exp.has_prepared(), (f'Experiment dir {exp.work_dir} is not ready to train.' | ||
f' Please run "prep" sub task') | ||
assert exp.has_trained(), (f"Experiment dir {exp.work_dir} is not ready to decode." | ||
f" Please run 'train' sub task or --skip-check to ignore this") | ||
|
||
def parse_args(): | ||
parser = ArgumentParser( | ||
prog="rtg.serve", | ||
description="deploy a model to a RESTful react server", | ||
formatter_class=ArgumentDefaultsHelpFormatter, | ||
) | ||
parser.add_argument("exp_dir", help="Experiment directory", type=str) | ||
parser.add_argument("-sc", "--skip-check", action="store_true", | ||
help="Skip Checking whether the experiment dir is prepared and trained") | ||
parser.add_argument("-p", "--port", type=int, help="port to run server on", default=6060) | ||
parser.add_argument("-ho", "--host", help="Host address to bind.", default='0.0.0.0') | ||
parser.add_argument("-msl", "--max-src-len", type=int, | ||
help="max source len; longer seqs will be truncated") | ||
args = vars(parser.parse_args()) | ||
return args | ||
|
||
|
||
def main(): | ||
cli_args = parse_args() | ||
decoder, dec_args = prepare_decoder(cli_args) | ||
app = Flask(__name__) | ||
#CORS(app) # TODO: insecure | ||
app.debug = True | ||
attach_translate_route(app, decoder, dec_args) | ||
app.run(port=cli_args["port"], host=cli_args["host"]) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters