Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

add stock market notebook / library.json #876

Merged
merged 3 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llama_hub/llama_packs/library.json
Original file line number Diff line number Diff line change
Expand Up @@ -257,5 +257,10 @@
"id": "llama_packs/multi_tenancy_rag",
"author": "ravi03071991",
"keywords": ["multi-tenancy", "multi", "tenancy", "rag"]
},
"StockMarketDataQueryEnginePack": {
"id": "llama_packs/stock_market_data_query_engine",
"author": "anoopshrma",
"keywords": ["stock", "market", "data", "query", "engine"]
}
}
18 changes: 14 additions & 4 deletions llama_hub/llama_packs/stock_market_data_query_engine/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional

from llama_index.llama_pack.base import BaseLlamaPack
from llama_index.schema import IndexNode
Expand All @@ -7,6 +7,9 @@
from llama_index.retrievers import RecursiveRetriever
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.llms.llm import LLM
from llama_index.llms import OpenAI
from llama_index.service_context import ServiceContext


class StockMarketDataQueryEnginePack(BaseLlamaPack):
Expand All @@ -15,6 +18,7 @@ class StockMarketDataQueryEnginePack(BaseLlamaPack):
def __init__(
self,
tickers: List[str],
llm: Optional[LLM] = None,
**kwargs: Any,
):
self.tickers = tickers
Expand All @@ -39,8 +43,11 @@ def __init__(
stocks_market_data.append(hist)
self.stocks_market_data = stocks_market_data

service_context = ServiceContext.from_defaults(llm=llm or OpenAI(model="gpt-4"))

df_price_query_engines = [
PandasQueryEngine(stock) for stock in stocks_market_data
PandasQueryEngine(stock, service_context=service_context)
for stock in stocks_market_data
]

summaries = [f"{ticker} historical market data" for ticker in tickers]
Expand All @@ -55,7 +62,9 @@ def __init__(
for idx, df_engine in enumerate(df_price_query_engines)
}

stock_price_vector_index = VectorStoreIndex(df_price_nodes)
stock_price_vector_index = VectorStoreIndex(
df_price_nodes, service_context=service_context
)
stock_price_vector_retriever = stock_price_vector_index.as_retriever(
similarity_top_k=1
)
Expand All @@ -69,7 +78,8 @@ def __init__(

stock_price_response_synthesizer = get_response_synthesizer(
# service_context=service_context,
response_mode="compact"
response_mode="compact",
service_context=service_context,
)

stock_price_query_engine = RetrieverQueryEngine.from_args(
Expand Down
Loading