Skip to content

Commit

Permalink
Merge pull request #922 from weni-ai/feature/zeroshot-5.0.0
Browse files Browse the repository at this point in the history
Feature/zeroshot 5.0.0
  • Loading branch information
johncordeiro authored Aug 29, 2024
2 parents 9e9da85 + a908464 commit 00f223d
Show file tree
Hide file tree
Showing 7 changed files with 2,025 additions and 2,103 deletions.
43 changes: 27 additions & 16 deletions bothub/api/v2/zeroshot/usecases/format_classification.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
import unicodedata
import re

from django.conf import settings
from .format_prompt import FormatPrompt


class FormatClassification:
const_none_class_id = -1
const_none_class = {"por": "Nenhuma", "eng": "None", "spa": "Ninguna"}
const_none_class = "None"

def __init__(self, language: str, classification_data: dict):
self.language = language
def __init__(self, classification_data: dict):
self.model_backend = settings.ZEROSHOT_MODEL_BACKEND
self.classification_data = classification_data

def get_classification(self, zeroshot_data):
    """Format the stored raw model response according to the configured backend.

    Args:
        zeroshot_data: request payload whose "options" list is used to map
            the predicted id back to a class label.

    Returns:
        dict with "other" and "classification" keys (see _get_formatted_output).

    Raises:
        ValueError: if settings.ZEROSHOT_MODEL_BACKEND is neither
            "runpod" nor "bedrock".
    """
    if self.model_backend == "runpod":
        return self._get_runpod_classification(zeroshot_data)
    elif self.model_backend == "bedrock":
        return self._get_bedrock_classification(zeroshot_data)
    else:
        raise ValueError(f"Unsupported model backend: {self.model_backend}")

def _get_data_none_class(self):
return self.const_none_class[self.language]
return self.const_none_class

def _get_number_from_output(self, output):
output_result = self.const_none_class_id
Expand All @@ -23,27 +32,29 @@ def _get_number_from_output(self, output):

return output_result

def _get_final_output(self):
output_text = self.classification_data.get("output").get("text")
response_text = output_text[0] if output_text else self._get_data_none_class()
def _get_formatted_output(self, output_text, zeroshot_data):
    """Map the raw model output text onto one of the caller's classes.

    Args:
        output_text: raw text produced by the model (may be empty or None).
        zeroshot_data: request payload; "options" holds the candidate
            classes, each a dict with an "id" and a "class" label.

    Returns:
        dict with "other" (True when nothing matched) and "classification"
        (the matched class label, or the none-class placeholder).
    """
    classification = {"other": True, "classification": self._get_data_none_class()}
    response_text = output_text if output_text else self._get_data_none_class()

    # Strip the trailing punctuation / quoting the model tends to wrap
    # around the predicted id.
    response_prepared = response_text.strip().strip(".").strip("\n").strip("'")

    output = self._get_number_from_output(response_prepared)

    # Bug fix: the original guard `output or output != const_none_class_id`
    # was always true because -1 is truthy. Only scan the options when an
    # actual id was extracted (outcome is identical: -1 never matched).
    if output != self.const_none_class_id:
        all_classes = zeroshot_data.get("options")
        for class_obj in all_classes:
            if output == str(class_obj.get("id")):
                classification["other"] = False
                classification["classification"] = class_obj.get("class")
                break

    return classification

def _get_runpod_classification(self, zeroshot_data):
    # RunPod response shape: output[0].choices[0].tokens[0] carries the
    # generated text — assumes a single-sample, single-choice request;
    # a malformed response raises (TypeError/IndexError) here rather than
    # being swallowed. TODO(review): confirm shape against the RunPod API.
    output_text = self.classification_data.get("output")[0].get("choices")[0].get("tokens")[0]
    return self._get_formatted_output(output_text, zeroshot_data)

def _get_bedrock_classification(self, zeroshot_data):
    # Bedrock response shape: outputs[0].text carries the generated text;
    # strip() drops leading/trailing whitespace before id extraction.
    output_text = self.classification_data.get("outputs")[0].get("text").strip()
    return self._get_formatted_output(output_text, zeroshot_data)
43 changes: 13 additions & 30 deletions bothub/api/v2/zeroshot/usecases/format_prompt.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,39 @@
BASE_PROMPT = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{context_description} {context}
{reflection}
# {classes_title}
{all_classes}<|eot_id|><|start_header_id|>user<|end_header_id|>
{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>
{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

class FormatPrompt:
const_prompt_data = {
"por": {
"context_description": "Você é muito especialista em classificar a frase do usuário em um chatbot sobre:",
"reflection": "Pare, pense bem e responda com APENAS UM ÚNICO `id` da classe que melhor represente a intenção para a frase do usuário de acordo com a análise de seu contexto, responda APENAS com o `id` da classe só se você tiver muita certeza e não explique o motivo. Na ausência, falta de informações ou caso a frase do usuário não se enquadre em nenhuma classe, classifique como \"-1\".",
"classes_title": "Essas são as Classes com seus Id e Contexto:"
},
"eng": {
"context_description": "You are very expert in classifying the user sentence in a chatbot about:",
"reflection": "Stop, think carefully, and respond with ONLY ONE SINGLE `id` of the class that best represents the intention for the user's sentence according to the analysis of its context, respond ONLY with the `id` of the class if you are very sure and do not explain the reason. In the absence, lack of information, or if the user's sentence does not fit into any class, classify as \"-1\".",
"classes_title": "These are the Classes and its Context:"
},
"spa": {
"context_description": "Eres muy experto en clasificar la frase del usuario en un chatbot sobre:",
"reflection": "Deténgase, piense bien y responda con SOLO UN ÚNICO `id` de la clase que mejor represente la intención para la frase del usuario de acuerdo con el análisis de su contexto, responda SOLO con el `id` de la clase si está muy seguro y no explique el motivo. En ausencia, falta de información o en caso de que la frase del usuario no se ajuste a ninguna clase, clasifique como \"-1\".",
"classes_title": "Estas son las Clases con sus Id y Contexto:"
}
"system_prompt": "Task: Classify the 'User' message within a chatbot about: {context}. Carefully consider the context and respond with ONLY ONE tag of the class that best represents the intent of the 'User' message with the below categories.\n\n<BEGIN CONTENT CATEGORIES>\n{classes_formatted}\n<END CONTENT CATEGORIES>",
"question": "{input}",
}

def generate_prompt(self, language: str, zeroshot_data: dict):
translated_text = self.const_prompt_data[language]
context = zeroshot_data.get("context")
input = zeroshot_data.get("text")
all_classes = self.setup_ids_on_classes(zeroshot_data.get("options"))
classes_formatted = self.format_classes(all_classes)

prompt = BASE_PROMPT.format(context_description=translated_text.get("context_description"),
reflection=translated_text.get("reflection"),
classes_title=translated_text.get("classes_title"),
context=context,
all_classes=all_classes,
input=input)
print(f"prompt: {prompt}")
system_prompt = self.const_prompt_data["system_prompt"].format(context=context, classes_formatted=classes_formatted)
question = self.const_prompt_data["question"].format(input=input)

prompt = BASE_PROMPT.format(system_prompt=system_prompt, input=question)
return prompt

def setup_ids_on_classes(self, all_classes):
    """Assign sequential 1-based "id" keys to each class dict, in place.

    Returns the same list so callers can chain the result.
    """
    # enumerate(..., start=1) replaces the manual index arithmetic and the
    # local variable `id`, which shadowed the builtin of the same name.
    for class_id, class_obj in enumerate(all_classes, start=1):
        class_obj["id"] = class_id

    return all_classes

def format_classes(self, all_classes):
    """Render the candidate classes as the prompt's category list.

    Each class becomes a line "A<id>: <label> - <context>", followed by a
    final synthetic "none" category used when no class fits.
    """
    # The index produced by the original enumerate() was unused; iterate
    # the classes directly.
    classes_formatted = '\n'.join(
        f"A{mclass['id']}: {mclass['class']} - {mclass['context']}"
        for mclass in all_classes
    )
    classes_formatted += f"\nA{len(all_classes)+1}: none - if there is insufficient information or if the User message doesn't fit any class"
    return classes_formatted

def get_default_language(self):
    """Return the language code assumed when the request omits one."""
    return "por"
108 changes: 108 additions & 0 deletions bothub/api/v2/zeroshot/usecases/invoke_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import json
import requests
import boto3

from django.conf import settings
from .format_classification import FormatClassification
from .format_prompt import FormatPrompt


class InvokeModel:
    """Builds the zeroshot prompt and invokes the configured model backend.

    The backend is selected by ``settings.ZEROSHOT_MODEL_BACKEND``:
    ``"runpod"`` posts the prompt to an HTTP endpoint, ``"bedrock"`` calls
    the AWS Bedrock runtime. Both paths return ``{"output": <formatted>}``.
    """

    def __init__(self, zeroshot_data):
        # zeroshot_data: request payload ("text", "options", "context",
        # optional "language") consumed by FormatPrompt / FormatClassification.
        self.model_backend = settings.ZEROSHOT_MODEL_BACKEND
        self.zeroshot_data = zeroshot_data

    def invoke(self):
        """Generate the prompt and dispatch to the configured backend.

        Raises:
            ValueError: if the configured backend is not supported.
        """
        prompt = self._get_prompt(self.zeroshot_data)

        if self.model_backend == "runpod":
            return self._invoke_runpod(prompt)
        elif self.model_backend == "bedrock":
            return self._invoke_bedrock(prompt)
        else:
            raise ValueError(f"Unsupported model backend: {self.model_backend}")

    def _get_prompt(self, zeroshot_data):
        # Fall back to the formatter's default language when the request
        # does not carry one.
        prompt_formatter = FormatPrompt()
        language = zeroshot_data.get("language", prompt_formatter.get_default_language())
        return prompt_formatter.generate_prompt(language, zeroshot_data)

    def _invoke_runpod(self, prompt):
        """POST the prompt to the RunPod endpoint and format its answer.

        On a non-200 response the result is silently ``{"output": {}}`` —
        the caller treats an empty output as an error (best-effort contract).
        """
        payload = json.dumps({
            "input": {
                "prompt": prompt,
                "sampling_params": {
                    "max_tokens": settings.ZEROSHOT_MAX_TOKENS,
                    "n": settings.ZEROSHOT_N,
                    "top_p": settings.ZEROSHOT_TOP_P,
                    "top_k": settings.ZEROSHOT_TOP_K,
                    "temperature": settings.ZEROSHOT_TEMPERATURE,
                    "stop": settings.ZEROSHOT_STOP
                }
            }
        })

        headers = {
            # Bug fix: "charset: utf-8" is a malformed media-type parameter;
            # RFC 9110 requires the form "charset=utf-8".
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {settings.ZEROSHOT_TOKEN}",
        }
        response = {"output": {}}

        url = settings.ZEROSHOT_BASE_NLP_URL
        if len(settings.ZEROSHOT_SUFFIX) > 0:
            url += settings.ZEROSHOT_SUFFIX
        # NOTE(review): no timeout is set, so this call can block
        # indefinitely — consider requests.post(..., timeout=...).
        response_nlp = requests.post(
            headers=headers,
            url=url,
            data=payload
        )

        if response_nlp.status_code == 200:
            classification = response_nlp.json()
            classification_formatter = FormatClassification(classification)
            formatted_classification = classification_formatter.get_classification(self.zeroshot_data)

            response["output"] = formatted_classification

        return response

    def _invoke_bedrock(self, prompt):
        """Invoke the Bedrock model synchronously and format its answer."""
        response = {"output": {}}

        session = boto3.Session(
            aws_access_key_id=settings.ZEROSHOT_BEDROCK_AWS_KEY,
            aws_secret_access_key=settings.ZEROSHOT_BEDROCK_AWS_SECRET,
            region_name=settings.ZEROSHOT_BEDROCK_AWS_REGION
        )

        bedrock_runtime = session.client('bedrock-runtime')
        payload = json.dumps({
            "max_tokens": settings.ZEROSHOT_MAX_TOKENS,
            "top_p": settings.ZEROSHOT_TOP_P,
            "top_k": settings.ZEROSHOT_TOP_K,
            "stop": settings.ZEROSHOT_STOP,
            "temperature": settings.ZEROSHOT_TEMPERATURE,
            "prompt": prompt
        })

        bedrock_response = bedrock_runtime.invoke_model(
            body=payload,
            contentType='application/json',
            accept='application/json',
            modelId=settings.ZEROSHOT_BEDROCK_MODEL_ID,
            trace='ENABLED'
        )

        # invoke_model returns the body as a StreamingBody; read and decode
        # it exactly once.
        classification = json.loads(bedrock_response['body'].read().decode('utf-8'))

        classification_formatter = FormatClassification(classification)
        formatted_classification = classification_formatter.get_classification(self.zeroshot_data)

        response["output"] = formatted_classification
        return response


57 changes: 8 additions & 49 deletions bothub/api/v2/zeroshot/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import requests
import logging
import json
import traceback

from django.conf import settings

Expand All @@ -16,8 +17,7 @@
ZeroshotLogs
)

from .usecases.format_prompt import FormatPrompt
from .usecases.format_classification import FormatClassification
from .usecases.invoke_model import InvokeModel

from bothub.api.v2.zeroshot.permissions import ZeroshotTokenPermission

Expand Down Expand Up @@ -84,63 +84,22 @@ class ZeroShotFastPredictAPIView(APIView):
permission_classes = [ZeroshotTokenPermission]

def post(self, request):

data = request.data

prompt_formatter = FormatPrompt()

language = data.get("language", prompt_formatter.get_default_language())
prompt = prompt_formatter.generate_prompt(language, data)

payload = json.dumps({
"input": {
"prompt": prompt,
"sampling_params": {
"max_tokens": settings.ZEROSHOT_MAX_TOKENS,
"n": settings.ZEROSHOT_N,
"top_p": settings.ZEROSHOT_TOP_P,
"tok_k": settings.ZEROSHOT_TOK_K,
"temperature": settings.ZEROSHOT_TEMPERATURE,
"do_sample": settings.ZEROSHOT_DO_SAMPLE,
"stop": settings.ZEROSHOT_STOP
}

}
})

headers = {
"Content-Type": "application/json; charset: utf-8",
"Authorization": f"Bearer {settings.ZEROSHOT_TOKEN}",
}
response_nlp = None
try:
url = settings.ZEROSHOT_BASE_NLP_URL
if len(settings.ZEROSHOT_SUFFIX) > 0:
url += settings.ZEROSHOT_SUFFIX
response_nlp = requests.post(
headers=headers,
url=url,
data=payload
)

response = {"output": {}}
if response_nlp.status_code == 200:
classification = response_nlp.json()
classification_formatter = FormatClassification(language, classification)
formatted_classification = classification_formatter.get_classify(data)

response["output"] = formatted_classification
invoke_model = InvokeModel(data)
response = invoke_model.invoke()

ZeroshotLogs.objects.create(
text=data.get("text"),
classification=response["output"].get("classification"),
other=response["output"].get("other", False),
options=data.get("options"),
nlp_log=str(response_nlp.json()),
nlp_log=str(json.dumps(response)),
language=data.get("language")
)

return Response(status=response_nlp.status_code, data=response if response_nlp.status_code == 200 else {"error": response_nlp.text})
return Response(status=200, data=response if response.get("output") else {"error": response})
except Exception as error:
traceback.print_exc()
logger.error(f"[ - ] Zeroshot fast predict: {error}")
return Response(status=response_nlp.status_code if response_nlp else 500, data={"error": error})
return Response(status=500, data={"error": str(error)})
12 changes: 9 additions & 3 deletions bothub/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,11 @@
ZEROSHOT_MAX_TOKENS = (int, 20),
ZEROSHOT_N = (int, 1),
ZEROSHOT_TOP_P = (float, 0.95),
ZEROSHOT_TOK_K = (int, 10),
ZEROSHOT_TOP_K = (int, 10),
ZEROSHOT_TEMPERATURE = (float, 0.1),
ZEROSHOT_DO_SAMPLE = (bool, False),
ZEROSHOT_STOP = (str, "\n"),
ZEROSHOT_STOP = (list, ["<|end_of_text|>", "<|eot_id|>"]),
ZEROSHOT_MODEL_BACKEND = (str, "runpod"),
)

# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
Expand Down Expand Up @@ -720,7 +721,12 @@
ZEROSHOT_MAX_TOKENS = env.int("ZEROSHOT_MAX_TOKENS")
ZEROSHOT_N = env.int("ZEROSHOT_N")
ZEROSHOT_TOP_P = env.float("ZEROSHOT_TOP_P")
ZEROSHOT_TOK_K = env.int("ZEROSHOT_TOK_K")
ZEROSHOT_TOP_K = env.int("ZEROSHOT_TOP_K")
ZEROSHOT_TEMPERATURE = env.float("ZEROSHOT_TEMPERATURE")
ZEROSHOT_DO_SAMPLE = env.bool("ZEROSHOT_DO_SAMPLE")
ZEROSHOT_STOP = env.str("ZEROSHOT_STOP")
ZEROSHOT_MODEL_BACKEND = env.str("ZEROSHOT_MODEL_BACKEND")
ZEROSHOT_BEDROCK_AWS_KEY = env.str("ZEROSHOT_BEDROCK_AWS_KEY")
ZEROSHOT_BEDROCK_AWS_SECRET = env.str("ZEROSHOT_BEDROCK_AWS_SECRET")
ZEROSHOT_BEDROCK_AWS_REGION = env.str("ZEROSHOT_BEDROCK_AWS_REGION")
ZEROSHOT_BEDROCK_MODEL_ID = env.str("ZEROSHOT_BEDROCK_MODEL_ID")
Loading

0 comments on commit 00f223d

Please sign in to comment.