Merge pull request #24 from jonfairbanks/develop
Updated Docstrings & Demo GIF
jonfairbanks authored Mar 6, 2024
2 parents 116fa25 + 60e6946 commit e8aa361
Showing 7 changed files with 239 additions and 30 deletions.
Binary file modified demo.gif
4 changes: 2 additions & 2 deletions docs/todo.md
@@ -1,6 +1,6 @@
# To Do

Below is a rough outline of proposesd features and outstanding issues that are being tracked.
Below is a rough outline of proposed features and outstanding issues that are being tracked.

Although not final, items are generally sorted from highest to lowest priority.

@@ -35,7 +35,7 @@ Although not final, items are generally sorted from highest to lowest priority.
- [x] Enable Caching
- [ ] Swap Repo & Website input to [Streamlit-Tags](https://gagan3012-streamlit-tags-examplesapp-7aiy65.streamlit.app)
- [ ] Allow Users to Set LLM Settings
- [x] System Prompt
- [ ] System Prompt (needs more work)
- [x] Chat Mode
- [ ] Temperature
- [x] top_k
39 changes: 37 additions & 2 deletions utils/helpers.py
@@ -23,6 +23,12 @@ def save_uploaded_file(uploaded_file: bytes, save_dir: str):
Args:
uploaded_file (BytesIO): The uploaded file content.
save_dir (str): The directory where the file will be saved.
Returns:
None
Raises:
Exception: If there is an error saving the file to disk.
"""
try:
if not os.path.exists(save_dir):
@@ -43,6 +49,18 @@ def save_uploaded_file(uploaded_file: bytes, save_dir: str):


def validate_github_repo(repo: str):
"""
Validates whether a GitHub repository exists.
Args:
repo (str): The name of the GitHub repository.
Returns:
True if the repository exists, False otherwise.
Raises:
Exception: If there is an error validating the repository.
"""
repo_endpoint = "https://github.com/" + repo + ".git"
resp = requests.head(repo_endpoint)
if resp.status_code == 200:
@@ -62,8 +80,14 @@ def clone_github_repo(repo: str):
"""
Clones a GitHub repository.
Parameters:
Args:
repo (str): The name of the GitHub repository.
Returns:
True if the repository is cloned successfully, False otherwise.
Raises:
Exception: If there is an error cloning the repository.
"""
repo_endpoint = "https://github.com/" + repo + ".git"
if repo_endpoint is not None:
@@ -90,7 +114,18 @@ def clone_github_repo(repo: str):


def get_file_metadata(file_path):
"""Returns a dictionary containing various metadata for the specified file."""
"""
Extracts various metadata for the specified file.
Args:
file_path (str): The path to the file.
Returns:
A dictionary containing the extracted metadata.
Raises:
Exception: If there is an error extracting the metadata.
"""
try:
with ExifToolHelper() as et:
for d in et.get_metadata(file_path):
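For reference, a minimal usage sketch based on the docstrings added to utils/helpers.py above; the import path, repository slug, and file path shown here are assumptions for illustration, not something this diff confirms:

```python
from utils.helpers import (
    validate_github_repo,
    clone_github_repo,
    get_file_metadata,
)

repo = "jonfairbanks/local-rag"  # hypothetical repository slug

# validate_github_repo() issues a HEAD request against https://github.com/<repo>.git
if validate_github_repo(repo):
    # clone_github_repo() returns True on success, False otherwise
    if clone_github_repo(repo):
        # get_file_metadata() returns a dict of metadata extracted via ExifTool
        print(get_file_metadata("README.md"))  # hypothetical file path
```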
102 changes: 96 additions & 6 deletions utils/llama_index.py
@@ -28,6 +28,21 @@
def setup_embedding_model(
model: str,
):
"""
Sets up an embedding model using the Hugging Face library.
Args:
model (str): The name of the embedding model to use.
Returns:
An instance of the HuggingFaceEmbedding class, configured with the specified model and device.
Raises:
ValueError: If the specified model is not a valid embedding model.
Notes:
The device is selected automatically: if CUDA is available, the embedding model runs on the GPU ('cuda'); otherwise it runs on the CPU ('cpu').
"""
device = 'cpu' if not cuda.is_available() else 'cuda'
embed_model = HuggingFaceEmbedding(
model_name=model,
@@ -48,11 +63,32 @@ def setup_embedding_model(

def create_service_context(
llm, # TODO: Determine type
system_prompt: str = None, # TODO: What are the implications of no system prompt being passed?
system_prompt: str = None,
embed_model: str = "BAAI/bge-large-en-v1.5",
chunk_size: int = 1024, # Llama-Index default is 1024
chunk_overlap: int = 200, # Llama-Index default is 200
):
"""
Creates a service context for the Llama language model.
Args:
llm (tbd): The Ollama language model to use.
system_prompt (str): An optional string that can be used as the system prompt when generating text. If no system prompt is passed, the default value will be used.
embed_model (str): The name of the embedding model to use. Can also be a path to a saved embedding model.
chunk_size (int): The size of each chunk of text to generate. Defaults to 1024.
chunk_overlap (int): The amount of overlap between adjacent chunks of text. Defaults to 200.
Returns:
A ServiceContext instance, configured with the specified Llama model, system prompt, and embedding model.
Raises:
ValueError: If the specified Llama model is not a valid Llama model.
ValueError: If the specified embed_model is not a valid embedding model.
Notes:
The `embed_model` parameter can be set to a path to a saved embedding model, or to a string representing the name of the embedding model to use. If the `embed_model` parameter is set to a path, it will be loaded and used to create the service context. Otherwise, it will be created using the specified name.
The `chunk_size` and `chunk_overlap` parameters can be adjusted to control how much text is generated in each chunk and how much overlap there is between chunks.
"""
formatted_embed_model = f"local:{embed_model}"
try:
embedding_model = setup_embedding_model(embed_model)
@@ -83,6 +119,21 @@ def create_service_context(


def load_documents(data_dir: str):
"""
Loads documents from a directory of files.
Args:
data_dir (str): The path to the directory containing the documents to be loaded.
Returns:
A list of documents, where each document is a string representing the content of the corresponding file.
Raises:
Exception: If there is an error creating the data index.
Notes:
The `data_dir` parameter should be a path to a directory containing files that represent the documents to be loaded. The function will iterate over all files in the directory, and load their contents into a list of strings.
"""
try:
files = SimpleDirectoryReader(input_dir=data_dir, recursive=True)
documents = files.load_data(files)
@@ -105,13 +156,34 @@ def load_documents(data_dir: str):

@st.cache_data(show_spinner=False)
def create_index(_documents, _service_context):
index = VectorStoreIndex.from_documents(
documents=_documents, service_context=_service_context, show_progress=True
)
logs.log.info("Index created from loaded documents successfully")
return index
"""
Creates an index from the provided documents and service context.
Args:
documents (list[str]): A list of strings representing the content of the documents to be indexed.
service_context (ServiceContext): The service context to use when creating the index.
Returns:
An instance of `VectorStoreIndex`, containing the indexed data.
Raises:
Exception: If there is an error creating the index.
Notes:
The `documents` parameter should be a list of strings representing the content of the documents to be indexed. The `service_context` parameter should be an instance of `ServiceContext`, providing information about the Llama model and other configuration settings for the index.
"""

try:
index = VectorStoreIndex.from_documents(
documents=_documents, service_context=_service_context, show_progress=True
)

logs.log.info("Index created from loaded documents successfully")

return index
except Exception as err:
logs.log.error(f"Index creation failed: {err}")
return False


###################################
@@ -123,6 +195,24 @@ def create_index(_documents, _service_context):

@st.cache_data(show_spinner=False)
def create_query_engine(_documents, _service_context):
"""
Creates a query engine from the provided documents and service context.
Args:
documents (list[str]): A list of strings representing the content of the documents to be indexed.
service_context (ServiceContext): The service context to use when creating the index.
Returns:
An instance of `QueryEngine`, containing the indexed data and allowing for querying of the data using a variety of parameters.
Raises:
Exception: If there is an error creating the query engine.
Notes:
The `documents` parameter should be a list of strings representing the content of the documents to be indexed. The `service_context` parameter should be an instance of `ServiceContext`, providing information about the Llama model and other configuration settings for the index.
This function uses the `create_index` function to create an index from the provided documents and service context, and then builds a query engine from the resulting index, configured with the number of top-ranked items to return (`similarity_top_k`), the response mode (`response_mode`), and the service context (`service_context`).
"""
try:
index = create_index(_documents, _service_context)

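Taken together, the docstrings added to utils/llama_index.py above describe an indexing pipeline: build a service context around an Ollama LLM and an embedding model, load documents from a directory, then create an index and a query engine from them. A minimal sketch of that flow, assuming the utils.llama_index import path, a locally available Ollama model, and a ./data directory (none of which this diff confirms):

```python
from llama_index.llms.ollama import Ollama

from utils.llama_index import (
    create_service_context,
    load_documents,
    create_query_engine,
)

# Assumed: an Ollama server on localhost serving a model named "llama2"
llm = Ollama(model="llama2", base_url="http://localhost:11434")

# Wraps the LLM, system prompt, embedding model, and chunking settings
service_context = create_service_context(
    llm,
    system_prompt="You are a helpful assistant.",  # optional
    embed_model="BAAI/bge-large-en-v1.5",          # default from the docstring
    chunk_size=1024,
    chunk_overlap=200,
)

documents = load_documents("data")  # assumed directory of source files

# create_query_engine() builds the VectorStoreIndex internally via create_index()
query_engine = create_query_engine(documents, service_context)
```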
17 changes: 16 additions & 1 deletion utils/logs.py
@@ -1,8 +1,23 @@
import logging
import sys

from typing import Union

def setup_logger(log_file="local-rag.log", level=logging.INFO):

def setup_logger(log_file: str = "local-rag.log", level: Union[int, str] = logging.INFO):
"""
Sets up a logger for this module.
Args:
log_file (str, optional): The file to which the logs should be written. Defaults to "local-rag.log".
level (int or str, optional): The logging level at which to log messages. Defaults to logging.INFO.
Returns:
logging.Logger: The set up logger.
Notes:
This function sets up a logger for this module using the `logging` library. It sets the logging level to the specified level, and adds handlers for both file and console output. The log file can be customized by passing a different name as the argument.
"""
logger = logging.getLogger(__name__)
logger.setLevel(level)

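A short sketch of the updated setup_logger() signature in use; the assumption that the module exposes the configured logger as logs.log comes from the logs.log.info(...) calls elsewhere in this diff:

```python
import logging

from utils.logs import setup_logger

# Accepts either an int level (logging.DEBUG) or its string name ("DEBUG")
log = setup_logger(log_file="local-rag.log", level=logging.DEBUG)
log.info("Logger configured with file and console handlers")
```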
83 changes: 64 additions & 19 deletions utils/ollama.py
@@ -9,6 +9,7 @@
os.environ["OPENAI_API_KEY"] = "sk-abc123"

from llama_index.llms.ollama import Ollama
from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine

###################################
#
@@ -25,11 +26,21 @@ def create_client(host: str):
- host (str): The hostname or IP address of the Ollama server.
Returns:
- ollama.Client: An instance of the Ollama client.
- An instance of the Ollama client.
Raises:
- Exception: If there is an error creating the client.
Notes:
This function creates a client for interacting with the Ollama API using the `ollama` library. It takes a single parameter, `host`, which should be the hostname or IP address of the Ollama server. The function returns an instance of the Ollama client, or raises an exception if there is an error creating the client.
"""
client = ollama.Client(host=host)
logs.log.info("Ollama chat client created successfully")
return client
try:
client = ollama.Client(host=host)
logs.log.info("Ollama chat client created successfully")
return client
except Exception as err:
logs.log.error(f"Failed to create Ollama client: {err}")
return False


###################################
@@ -44,16 +55,37 @@ def get_models():
Retrieves a list of available language models from the Ollama server.
Returns:
- models: A list of available language model names.
- models (list[str]): A list of available language model names.
Raises:
- Exception: If there is an error retrieving the list of models.
Notes:
This function retrieves a list of available language models from the Ollama server using the `ollama` library. It takes no parameters and returns a list of available language model names.
The function raises an exception if there is an error retrieving the list of models.
Side Effects:
- st.session_state["ollama_models"] is set to the list of available language models.
"""
chat_client = create_client(st.session_state["ollama_endpoint"])
data = chat_client.list()
models = []
for model in data["models"]:
models.append(model["name"])
logs.log.info("Ollama models loaded successuflly")
st.session_state["ollama_models"] = models
return models
try:
chat_client = create_client(st.session_state["ollama_endpoint"])
data = chat_client.list()
models = []
for model in data["models"]:
models.append(model["name"])

st.session_state["ollama_models"] = models

if len(models) > 0:
logs.log.info("Ollama models loaded successfully")
else:
logs.log.warn("Ollama did not return any models. Make sure to download some!")

return models
except Exception as err:
logs.log.error(f"Failed to retrieve Ollama model list: {err}")
return False


###################################
@@ -120,20 +152,33 @@ def chat(prompt: str):
###################################


def context_chat(prompt: str, query_engine):
def context_chat(prompt: str, query_engine: RetrieverQueryEngine):
"""
Initiates a chat with context using the Ollama language model and index.
Initiates a chat with context using the Llama-Index query_engine.
Parameters:
- prompt (str): The starting prompt for the conversation.
- query_engine (str): TODO: Write this section
- query_engine (RetrieverQueryEngine): The Llama-Index query engine to use for retrieving answers.
Yields:
- str: Successive chunks of conversation from the Ollama model with context.
- str: Successive chunks of conversation from the Llama-Index model with context.
Raises:
- Exception: If there is an error retrieving answers from the Llama-Index model.
Notes:
This function initiates a chat with context using the Llama-Index language model and index.
It takes two parameters, `prompt` and `query_engine`, which should be the starting prompt for the conversation and the Llama-Index query engine to use for retrieving answers, respectively.
The function returns an iterable yielding successive chunks of conversation from the Llama-Index index with context.
If there is an error retrieving answers from the Llama-Index instance, the function raises an exception.
Side Effects:
- The chat conversation is generated and returned as successive chunks of text.
"""

# print(type(query_engine)) # <class 'llama_index.core.query_engine.retriever_query_engine.RetrieverQueryEngine'>

try:
stream = query_engine.query(prompt)
for text in stream.response_gen:
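Finally, a minimal sketch of how the utils/ollama.py helpers above might be driven from a Streamlit session; the endpoint value, the data directory, and the pipeline used to produce query_engine (mirroring the utils/llama_index.py sketch earlier) are assumptions:

```python
import streamlit as st
from llama_index.llms.ollama import Ollama

from utils.llama_index import create_service_context, load_documents, create_query_engine
from utils.ollama import create_client, get_models, context_chat

# Assumed endpoint; get_models() reads it from session state
st.session_state["ollama_endpoint"] = "http://localhost:11434"

client = create_client(st.session_state["ollama_endpoint"])
models = get_models()  # side effect: populates st.session_state["ollama_models"]

# Build a query engine as in the llama_index.py sketch above
llm = Ollama(model=models[0], base_url=st.session_state["ollama_endpoint"])
service_context = create_service_context(llm)
documents = load_documents("data")  # assumed data directory
query_engine = create_query_engine(documents, service_context)

# context_chat() yields successive response chunks from the query engine
for chunk in context_chat("What does this repository do?", query_engine):
    print(chunk, end="")
```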