Skip to content

Commit

Permalink
Code styling typos and error behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
Lokhozt committed Sep 20, 2022
1 parent 78af626 commit 210ce88
Show file tree
Hide file tree
Showing 12 changed files with 169 additions and 123 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1 @@
start_container.sh
start_container.sh
1 change: 0 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,3 @@ HEALTHCHECK CMD ./healthcheck.sh

ENV TEMP=/usr/src/app/tmp
ENTRYPOINT ["./docker-entrypoint.sh"]
CMD ["serve"]
13 changes: 13 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
.DEFAULT_GOAL := help

target_dirs := punctuation http_server celery_app

help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

style: ## update code style.
black -l 100 ${target_dirs}
isort ${target_dirs}

lint: ## run pylint linter.
pylint ${target_dirs}
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ The punctuation service relies on a BERT model.
We provide some models on [dl.linto.ai](https://dl.linto.ai/downloads/model-distribution/punctuation_models/).

### Docker
The transcription service requires docker up and running.
The punctuation service requires docker up and running.

### (micro-service) Service broker
The punctuation only entry point in job mode are tasks posted on a message broker. Supported message broker are RabbitMQ, Redis, Amazon SQS.

## Deploy linto-platform-stt
linto-platform-stt can be deployed two ways:
## Deploy linto-platform-punctuation
linto-platform-punctuation can be deployed two ways:
* As a standalone punctuation service through an HTTP API.
* As a micro-service connected to a message broker.

Expand Down Expand Up @@ -79,7 +79,7 @@ linto-platform-punctuation:latest
| MODEL_PATH | Your localy available model (.mar) | /my/path/to/models/punctuation.mar |
| SERVICES_BROKER | Service broker uri | redis://my_redis_broker:6379 |
| BROKER_PASS | Service broker password (Leave empty if there is no password) | my_password |
| LANGUAGE | Transcription language | en-US |
| LANGUAGE | Punctuation language | en-US |
| CONCURRENCY | Number of worker (1 worker = 1 cpu) | [ 1 -> numberOfCPU] |

## Usages
Expand All @@ -96,7 +96,7 @@ Returns "1" if healthcheck passes.

#### /punctuation

Transcription API
Punctuation API

* Method: POST
* Response content: text/plain or application/json
Expand Down Expand Up @@ -125,7 +125,7 @@ The /docs route offers a OpenAPI/swagger interface.

### Through the message broker

STT-Worker accepts requests with the following arguments:
Punctuation-Worker accepts requests with the following arguments:
```file_path: str, with_metadata: bool```

* <ins>text</ins>: (str or list) A sentence or a list of sentences.
Expand Down
6 changes: 6 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# 1.0.1
- Changes behavior on prediction error from failed to ignore.
- Adds makefile for code styling (PEP 8)
- Fixes typos.
- Changes code style (PEP 8)

# 1.0.0
- Punctuation service.
- HTTP or Celery serving.
24 changes: 13 additions & 11 deletions celery_app/celeryapp.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
import os

from celery import Celery

from punctuation import logger

celery = Celery(__name__, include=['celery_app.tasks'])
celery = Celery(__name__, include=["celery_app.tasks"])
service_name = os.environ.get("SERVICE_NAME")
broker_url = os.environ.get("SERVICES_BROKER")
if os.environ.get("BROKER_PASS", False):
components = broker_url.split('//')
components = broker_url.split("//")
broker_url = f'{components[0]}//:{os.environ.get("BROKER_PASS")}@{components[1]}'
celery.conf.broker_url = "{}/0".format(broker_url)
celery.conf.result_backend = "{}/1".format(broker_url)
celery.conf.update(
result_expires=3600,
task_acks_late=True,
task_track_started = True)
celery.conf.broker_url = f"{broker_url}/0"
celery.conf.result_backend = f"{broker_url}/1"
celery.conf.update(result_expires=3600, task_acks_late=True, task_track_started=True)

# Queues
language = os.environ.get("LANGUAGE")
celery.conf.update(
{'task_routes': {
'punctuation_task' : {'queue': f"punctuation_{language}"},}
{
"task_routes": {
"punctuation_task": {"queue": f"punctuation_{language}"},
}
}
)
logger.info("Celery configured for broker located at {} with service name {}".format(broker_url, service_name))
logger.info(
f"Celery configured for broker located at {broker_url} with service name {service_name}"
)
44 changes: 27 additions & 17 deletions celery_app/tasks.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,47 @@
import os
import requests
import json
from celery_app.celeryapp import celery
from typing import Union

import requests

from celery_app.celeryapp import celery


@celery.task(name="punctuation_task", bind=True)
def punctuation_task(self, text: Union[str, list]):
""" punctuation_task do a synchronous call to the punctuation serving API """
"""punctuation_task do a synchronous call to the punctuation serving API"""
self.update_state(state="STARTED")
# Fetch model name
try:
result = requests.get("http://localhost:8081/models",
headers={"accept": "application/json",},)
result = requests.get(
"http://localhost:8081/models",
headers={
"accept": "application/json",
},
)
models = json.loads(result.text)
model_name = models["models"][0]["modelName"]
except:
raise Exception("Failed to fetch model name")
except Exception as error:
raise Exception("Failed to fetch model name") from error

if isinstance(text, str):
sentences = [text]
else:
sentences = text
punctuated_sentences = []
for i, sentence in enumerate(sentences):
self.update_state(state="STARTED", meta={"current": i, "total": len(sentences)})

result = requests.post("http://localhost:8080/predictions/{}".format(model_name),
headers={'content-type': 'application/octet-stream'},
data=sentence.strip().encode('utf-8'))

result = requests.post(
f"http://localhost:8080/predictions/{model_name}",
headers={"content-type": "application/octet-stream"},
data=sentence.strip().encode("utf-8"),
)
if result.status_code == 200:
punctuated_sentence = result.text
punctuated_sentence = punctuated_sentence[0].upper() + punctuated_sentence[1:]
punctuated_sentences.append(punctuated_sentence)
else:
raise Exception(result.text)
return punctuated_sentences[0] if len(punctuated_sentences) == 1 else punctuated_sentences
print("Failed to predict punctuation on sentence: >{sentence}<")
punctuated_sentence = sentence
punctuated_sentence = punctuated_sentence[0].upper() + punctuated_sentence[1:]
punctuated_sentences.append(punctuated_sentence)

return punctuated_sentences[0] if len(punctuated_sentences) == 1 else punctuated_sentences
62 changes: 28 additions & 34 deletions http_server/confparser.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,45 @@
import os
import argparse
import os

__all__ = ["createParser"]


def createParser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

# SERVICE
parser.add_argument(
'--service_name',
"--service_name",
type=str,
help='Service Name',
default=os.environ.get('SERVICE_NAME', 'punctuation'))

#GUNICORN
parser.add_argument(
'--service_port',
type=int,
help='Service port',
default=80)
help="Service Name",
default=os.environ.get("SERVICE_NAME", "punctuation"),
)

# GUNICORN
parser.add_argument("--service_port", type=int, help="Service port", default=80)
parser.add_argument(
'--workers',
"--workers",
type=int,
help="Number of Gunicorn workers (default=CONCURRENCY + 1)",
default=int(os.environ.get('CONCURRENCY', 1)) + 1)

#SWAGGER
parser.add_argument(
'--swagger_url',
type=str,
help='Swagger interface url',
default='/docs')
default=int(os.environ.get("CONCURRENCY", 1)) + 1,
)

# SWAGGER
parser.add_argument("--swagger_url", type=str, help="Swagger interface url", default="/docs")
parser.add_argument(
'--swagger_prefix',
"--swagger_prefix",
type=str,
help='Swagger prefix',
default=os.environ.get('SWAGGER_PREFIX', ''))
help="Swagger prefix",
default=os.environ.get("SWAGGER_PREFIX", ""),
)
parser.add_argument(
'--swagger_path',
"--swagger_path",
type=str,
help='Swagger file path',
default=os.environ.get('SWAGGER_PATH', '/usr/src/app/document/swagger.yml'))

#MISC
parser.add_argument(
'--debug',
action='store_true',
help='Display debug logs')
help="Swagger file path",
default=os.environ.get("SWAGGER_PATH", "/usr/src/app/document/swagger.yml"),
)

# MISC
parser.add_argument("--debug", action="store_true", help="Display debug logs")

return parser
return parser
Loading

0 comments on commit 210ce88

Please sign in to comment.