Skip to content

Commit

Permalink
[LT-4894] Add RAG class (#262)
Browse files Browse the repository at this point in the history
* add rag class
* address comments and add test;
* init for import
  • Loading branch information
isaac-chung authored Jan 8, 2024
1 parent f41e754 commit c6f1052
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 0 deletions.
3 changes: 3 additions & 0 deletions clarifai/rag/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .rag import RAG

__all__ = ["RAG"]
129 changes: 129 additions & 0 deletions clarifai/rag/rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from datetime import datetime

import yaml
from google.protobuf.struct_pb2 import Struct

from clarifai.client.app import App
from clarifai.client.model import Model
from clarifai.client.user import User
from clarifai.client.workflow import Workflow
from clarifai.utils.logging import get_logger


class RAG:
"""
RAG is a class for Retrieval Augmented Generation.
Example:
>>> from clarifai.rag import RAG
>>> rag_agent = RAG()
"""

def __init__(self,
workflow_url: str = None,
workflow: Workflow = None,
base_url: str = "https://api.clarifai.com",
pat: str = None,
**kwargs):
"""Initialize an empty or existing RAG.
"""
self.logger = get_logger(logger_level="INFO", name=__name__)
if workflow_url is not None and workflow is None:
self.logger.info("workflow_url:%s", workflow_url)
w = Workflow(workflow_url, base_url=base_url, pat=pat)
self._prompt_workflow = w
self._app = App(app_id=w.app_id, base_url=w.base, pat=w.pat)
elif workflow_url is None and workflow is not None:
self._prompt_workflow = workflow
self._app = App(app_id=workflow.app_id, base_url=workflow.base, pat=workflow.pat)

@classmethod
def setup(cls,
user_id: str = None,
llm_url: str = "https://clarifai.com/mistralai/completion/models/mistral-7B-Instruct",
base_workflow: str = "Text",
workflow_yaml_filename: str = 'prompter_wf.yaml',
base_url: str = "https://api.clarifai.com",
pat: str = None,
**kwargs):
"""Creates an app with `Text` as base workflow, create prompt model, create prompt workflow.
Example:
>>> from clarifai.rag import RAG
>>> rag_agent = RAG.setup()
"""
user = User(user_id=user_id, base_url=base_url, pat=pat)
llm = Model(llm_url)

params = Struct()
params.update({
"prompt_template":
"Context information is below:\n{data.hits}\nGiven the context information and not prior knowledge, answer the query.\nQuery: {data.text.raw}\nAnswer: "
})
prompter_model_params = {"params": params}

## Create an App
now_ts = str(int(datetime.now().timestamp()))
app_id = f"rag_app_{now_ts}"
app = user.create_app(app_id=app_id, base_workflow=base_workflow)

## Create rag-prompter model and version
prompter_model = app.create_model(
model_id=f"rag_prompter_{now_ts}", model_type_id="rag-prompter")
prompter_model = prompter_model.create_version(output_info=prompter_model_params)

## Generate a tmp yaml file for workflow creation
workflow_id = f"rag-wf-{now_ts}"
workflow_dict = {
"workflow": {
"id":
workflow_id,
"nodes": [{
"id": "rag-prompter",
"model": {
"model_id": prompter_model.id,
"model_version_id": prompter_model.model_version.id
}
}, {
"id": "llm",
"model": {
"model_id": llm.id,
"user_id": llm.user_id,
"app_id": llm.app_id
},
"node_inputs": [{
"node_id": "rag-prompter"
}]
}]
}
}
with open(workflow_yaml_filename, 'w') as out_file:
yaml.dump(workflow_dict, out_file, default_flow_style=False)

## Create prompt workflow
wf = app.create_workflow(config_filepath=workflow_yaml_filename)
del user, llm, prompter_model, prompter_model_params
return cls(workflow=wf)

# TODO: Implement this.
def upload():
"""Does the following:
- Read from a local directory or public url or local filename.
- Parse the document(s) into chunks.
- Ingest chunks into the app with metadata.
Example:
>>> from clarifai.rag import RAG
>>> rag_agent = RAG().setup()
>>> rag_agent.upload("~/work/docs")
>>> rag_agent.upload("~/work/docs/manual.pdf")
"""
pass

# TODO: Implement this.
def chat(message: str) -> str:
"""Call self._prompt_workflow.predict_by_bytes.
This will pass back the workflow state ID for the server to store chat state.
"""
pass
30 changes: 30 additions & 0 deletions tests/test_rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
from collections import namedtuple

import pytest

from clarifai.rag import RAG
from clarifai.urls.helper import ClarifaiUrlHelper

CREATE_APP_USER_ID = os.environ["CLARIFAI_USER_ID"]

auth_obj = namedtuple("auth", "ui")


@pytest.mark.requires_secrets
class TestRAG:
workflow_url = ""
app_id = ""

def test_setup(self):
app = RAG.setup(user_id=CREATE_APP_USER_ID)
wf = app._prompt_workflow
assert len(wf.workflow_info.nodes) == 2
auth = auth_obj(ui="https://clarifai.com")
self.workflow_url = ClarifaiUrlHelper(auth).clarifai_url(wf.user_id, wf.app_id, "workflows",
wf.id)
self.app_id = app._app.id

## test_from_existing_workflow
app = RAG(workflow_url=self.workflow_url)
assert app._app.id == self.app_id

0 comments on commit c6f1052

Please sign in to comment.