Skip to content

Commit

Permalink
feat: ssm.rag() w load, split, embed, store
Browse files Browse the repository at this point in the history
  • Loading branch information
lpm0073 committed Nov 30, 2023
1 parent a3b1c17 commit 2335d22
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 20 deletions.
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# Netec Large Language Models

A Python [LangChain](https://www.langchain.com/) - [Pinecone](https://docs.pinecone.io/docs/python-client) proof of concept LLM to manage sales support inquiries on the Netec course catalogue.
A Python [LangChain](https://www.langchain.com/) - [Pinecone](https://docs.pinecone.io/docs/python-client) proof-of-concept Retrieval Augmented Generation (RAG) model using sales support PDF documents.

See:

- [LangChain RAG](https://python.langchain.com/docs/use_cases/question_answering/)
- [LangChain Document Loaders](https://python.langchain.com/docs/modules/data_connection/document_loaders/pdf)
- [LangChain Caching](https://python.langchain.com/docs/modules/model_io/llms/llm_caching)

## Installation

Expand Down Expand Up @@ -28,6 +34,9 @@ python3 -m models.examples.training_services "Microsoft certified Azure AI engin

# example 4 - prompted assistant
python3 -m models.examples.training_services_oracle "Oracle database administrator"

# example 5 - RAG
python3 -m models.examples.rag "./data/" "What is Accounting Based Valuation?"
```

## Requirements
Expand Down
17 changes: 17 additions & 0 deletions models/examples/rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
"""Sales Support Model (SSM) Retrieval Augmented Generation (RAG).

Command-line entry point: loads PDF documents from a directory, embeds
them into Pinecone, and answers a question about their contents.

Usage:
    python3 -m models.examples.rag "./data/" "What is Accounting Based Valuation?"
"""
import argparse

from ..ssm import SalesSupportModel


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RAG example")
    parser.add_argument("filepath", type=str, help="Location of PDF documents")
    parser.add_argument("prompt", type=str, help="A question about the PDF contents")
    args = parser.parse_args()

    # Construct the model inside the entry-point guard so that merely
    # importing this module does not instantiate the model (and its
    # OpenAI/Pinecone clients) as a side effect, and invalid CLI args
    # fail fast before any client setup.
    ssm = SalesSupportModel()
    result = ssm.rag(filepath=args.filepath, prompt=args.prompt)
    print(result)
2 changes: 1 addition & 1 deletion models/prompt_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def oracle_training_services(self) -> PromptTemplate:
template = (
self.sales_role
+ """
Note that Netec is the exclusive provide of Oracle training services
Note that Netec is the exclusive provider in Latin America of Oracle training services
for the 6 levels of Oracle Certification credentials: Oracle Certified Junior Associate (OCJA),
Oracle Certified Associate (OCA), Oracle Certified Professional (OCP),
Oracle Certified Master (OCM), Oracle Certified Expert (OCE) and
Expand Down
101 changes: 83 additions & 18 deletions models/ssm.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,48 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-few-public-methods
"""Sales Support Model (SSM) for the LangChain project."""

"""
Sales Support Model (SSM) for the LangChain project.
See: https://python.langchain.com/docs/modules/model_io/llms/llm_caching
https://python.langchain.com/docs/modules/data_connection/document_loaders/pdf
"""

import glob
import os
from typing import ClassVar, List

import pinecone
from langchain import hub
from langchain.cache import InMemoryCache

# prompting and chat
from langchain.chat_models import ChatOpenAI

# document loading
from langchain.document_loaders import PyPDFLoader

# embedding
from langchain.embeddings import OpenAIEmbeddings

# vector database
from langchain.globals import set_llm_cache
from langchain.llms.openai import OpenAI
from langchain.prompts import PromptTemplate
from langchain.schema import HumanMessage, SystemMessage # AIMessage (not used)
from langchain.schema import HumanMessage, StrOutputParser, SystemMessage
from langchain.schema.runnable import RunnablePassthrough
from langchain.text_splitter import Document, RecursiveCharacterTextSplitter
from langchain.vectorstores.pinecone import Pinecone
from pydantic import BaseModel, ConfigDict, Field # ValidationError

# this project
from models.const import Credentials


###############################################################################
# initializations
###############################################################################
DEFAULT_MODEL_NAME = "text-davinci-003"
pinecone.init(api_key=Credentials.PINECONE_API_KEY, environment=Credentials.PINECONE_ENVIRONMENT)
set_llm_cache(InMemoryCache())


class SalesSupportModel(BaseModel):
Expand All @@ -31,24 +55,17 @@ class SalesSupportModel(BaseModel):
default_factory=lambda: ChatOpenAI(
api_key=Credentials.OPENAI_API_KEY,
organization=Credentials.OPENAI_API_ORGANIZATION,
cache=True,
max_retries=3,
model="gpt-3.5-turbo",
temperature=0.3,
temperature=0.0,
)
)

# embeddings
text_splitter: RecursiveCharacterTextSplitter = Field(
default_factory=lambda: RecursiveCharacterTextSplitter(
chunk_size=100,
chunk_overlap=0,
)
)

texts_splitter_results: List[Document] = Field(None, description="Text splitter results")
pinecone_search: Pinecone = Field(None, description="Pinecone search")
pinecone_index_name: str = Field(default="netec-ssm", description="Pinecone index name")
openai_embedding: OpenAIEmbeddings = Field(default_factory=lambda: OpenAIEmbeddings(model="ada"))
openai_embedding: OpenAIEmbeddings = Field(OpenAIEmbeddings())
query_result: List[float] = Field(None, description="Vector database query result")

def cached_chat_request(self, system_message: str, human_message: str) -> SystemMessage:
Expand All @@ -68,24 +85,72 @@ def prompt_with_template(self, prompt: PromptTemplate, concept: str, model: str

def split_text(self, text: str) -> List[Document]:
    """Split *text* into chunked LangChain Documents.

    Uses a RecursiveCharacterTextSplitter with a fixed 100-character
    chunk size and no overlap, constructed locally rather than read
    from instance state.

    :param text: raw text to split.
    :return: list of Document chunks.
    """
    # The collapsed diff left a stale duplicate assignment that read the
    # removed `self.text_splitter` field; only the local splitter remains.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=100,
        chunk_overlap=0,
    )
    return text_splitter.create_documents([text])

def embed(self, text: str) -> List[float]:
    """Embed *text* and persist the chunks to the Pinecone index.

    1. Split the text into 100-character chunks (no overlap) via split_text.
    2. Run an OpenAI embedding over the first chunk's content.
    3. Store all chunks in the Pinecone index via LangChain.

    NOTE(review): despite the ``List[float]`` annotation this method
    returns None; the embedding vectors live in Pinecone, not in the
    return value. Annotation kept for interface compatibility — confirm
    intended contract with callers.

    :param text: raw text to embed and store.
    """
    # Reuse split_text instead of duplicating the splitter construction
    # (the collapsed diff carried both the old and new versions).
    texts_splitter_results = self.split_text(text)
    embedding = texts_splitter_results[0].page_content
    # pylint: disable=no-member
    self.openai_embedding.embed_query(embedding)

    self.pinecone_search = Pinecone.from_documents(
        texts_splitter_results,
        embedding=self.openai_embedding,
        index_name=Credentials.PINECONE_INDEX_NAME,
    )

def rag(self, filepath: str, prompt: str):
    """Answer *prompt* using Retrieval Augmented Generation over PDFs.

    1. Load every PDF document found in *filepath*.
    2. Split the pages into overlapping chunks.
    3. Embed the chunks and store them in Pinecone.
    4. Retrieve relevant chunks and generate an answer with the chat model.

    :param filepath: directory containing the PDF documents.
    :param prompt: the user's question about the PDF contents.
    :return: the generated answer string.
    """

    def format_docs(docs):
        """Concatenate retrieved documents into a single context string."""
        return "\n\n".join(doc.page_content for doc in docs)

    # Collect pages from ALL PDFs before splitting. The original split
    # only the `docs` of the last file because the variable leaked out of
    # the loop; it also embedded every page a second time via self.embed,
    # duplicating the Pinecone writes done by from_documents below.
    all_docs = []
    for pdf_file in glob.glob(os.path.join(filepath, "*.pdf")):
        loader = PyPDFLoader(file_path=pdf_file)
        all_docs.extend(loader.load())

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(all_docs)
    vectorstore = Pinecone.from_documents(documents=splits, embedding=self.openai_embedding)
    retriever = vectorstore.as_retriever()

    # Do not shadow the caller's `prompt` parameter: the chain must be
    # invoked with the user's question, not the hub prompt template.
    rag_prompt = hub.pull("rlm/rag-prompt")

    rag_chain = (
        # `format_docs` is a local function, not a method on self —
        # `self.format_docs` raised AttributeError in the original.
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | rag_prompt
        | self.chat
        | StrOutputParser()
    )

    return rag_chain.invoke(prompt)

def embedded_prompt(self, prompt: str) -> List[Document]:
    """Retrieve document chunks relevant to *prompt*.

    Given a user input, relevant splits are retrieved from the Pinecone
    vector store using similarity search; a ChatModel / LLM can then
    generate an answer from the question plus the retrieved data.

    :param prompt: the user's question.
    :return: list of matching Document chunks.
    """
    return self.pinecone_search.similarity_search(prompt)
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,7 @@ codespell==2.2.6
python-dotenv==1.0.0
pydantic==2.5.2
langchain==0.0.343
openai==1.3.5
pinecone-client==2.2.4
pypdf==3.17.1
tiktoken==0.5.1

0 comments on commit 2335d22

Please sign in to comment.