diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 86443f042c..a6aa909f50 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -6,7 +6,7 @@ on:
- 'platform/**'
env:
- PYTHON_VERSION: "3.11"
+ PYTHON_VERSION: "3.10"
jobs:
black:
diff --git a/README.md b/README.md
index 0e5c5d4bc9..be623d8415 100644
--- a/README.md
+++ b/README.md
@@ -125,3 +125,84 @@ Our contributors have made this project possible. Thank you! 🙏
+
+## Arguments for OpenAIAgentService
+
+The `OpenAIAgentService` class is defined in `platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py`.
+
+The constructor of `OpenAIAgentService` takes the following arguments:
+
+- `model`: The model to be used.
+- `settings`: The settings for the model.
+- `token_service`: The token service for managing tokens.
+- `callbacks`: Optional list of callback handlers.
+- `user`: The user information.
+- `oauth_crud`: The OAuth CRUD operations.
+
+## Recent Code Changes
+
+### Overview
+
+The recent code changes update the `chat` method in the `OpenAIAgentService` class to handle `SystemMessagePromptTemplate` objects correctly. Message roles are now looked up defensively before use (LangChain messages expose `type` rather than `role`), preventing an `AttributeError` on messages that lack the attribute.
+
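+As a minimal sketch (assuming LangChain message objects, which expose `type` and `content` rather than `role`), the guarded conversion looks like this:
+
+```python
+# Hypothetical helper mirroring the defensive role lookup described above.
+ROLE_MAP = {"system": "system", "human": "user", "ai": "assistant"}
+
+def to_chat_message(msg):
+    # Fall back to "user" rather than raising AttributeError on unknown message types.
+    return {"role": ROLE_MAP.get(getattr(msg, "type", "human"), "user"), "content": msg.content}
+```
+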
+### Prerequisites and Setup Instructions
+
+The prerequisites and setup instructions have been updated to reflect the recent changes. Please follow the updated instructions in the "Getting Started" section to set up the project correctly.
+
+### New Features and Improvements
+
+- The `chat` method in the `OpenAIAgentService` class now handles `SystemMessagePromptTemplate` objects correctly.
+- Improved error handling and stability in the `chat` method.
+
+## Example of Initializing OpenAIAgentService
+
+Here is an example of initializing the `OpenAIAgentService`, setting up the environment, and running the main function:
+
+```python
+import asyncio
+import os
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+from fastapi.responses import StreamingResponse
+
+async def get_oauth_crud():
+ oauth_crud = await OAuthCrud.inject()
+ return oauth_crud
+
+async def main():
+ # Ensure the OPENAI_API_KEY is set
+ openai_api_key = os.getenv("OPENAI_API_KEY")
+ if not openai_api_key:
+ raise ValueError("The environment variable OPENAI_API_KEY is not set.")
+
+ # Initialize the OpenAIAgentService
+ model = WrappedChatOpenAI(model_name="llama3.2", openai_api_key=openai_api_key)
+ settings = ModelSettings(language="en")
+ token_service = TokenService.create()
+ callbacks = None
+ user = UserBase(id=1, name="John Doe")
+ oauth_crud = await get_oauth_crud()
+
+ agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+ # Chat with agent
+ response = await agent_service.pip_chat(message="Your message here", results=["result1", "result2"])
+
+ if isinstance(response, StreamingResponse):
+ response_content = []
+ async for chunk in response.body_iterator:
+ if isinstance(chunk, bytes):
+ response_content.append(chunk.decode('utf-8'))
+ else:
+ response_content.append(chunk)
+ print(''.join(response_content))
+ else:
+ print(response)
+
+# Run the main function
+asyncio.run(main())
+```
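+
+Because `pip_chat` returns a `FastAPIStreamingResponse` when streaming, the example drains `response.body_iterator` chunk by chunk instead of printing the response object directly.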
diff --git a/cli/README.md b/cli/README.md
index f63a652b84..a330039ab5 100644
--- a/cli/README.md
+++ b/cli/README.md
@@ -25,3 +25,38 @@ To update ENV values:
- Add a question to the list of questions in `index.js` for the ENV value
- Add a value in the `envDefinition` for the ENV value
- Add the ENV value to the `.env.example` in the root of the project
+
+### Recent Changes
+
+The CLI tool has been updated to include new features and improvements. These changes include:
+
+- Enhanced validation for ENV values
+- Improved user prompts for a better interactive experience
+- Support for additional configuration options
+- Bug fixes and performance enhancements
+
+### Running the CLI Tool
+
+To run the CLI tool with the recent changes, follow these steps:
+
+1. Navigate to the root of the project and run the setup script:
+
+ ```bash
+ ./setup.sh
+ ```
+
+2. Alternatively, you can navigate to the `cli` directory and start the tool:
+
+ ```bash
+ cd cli/
+ npm run start
+ ```
+
+### Overview of New Features and Improvements
+
+The recent updates to the CLI tool include the following new features and improvements:
+
+- **Enhanced Validation**: The tool now includes more robust validation for ENV values, ensuring that all required values are correctly set.
+- **Improved User Prompts**: The interactive prompts have been improved to provide a better user experience, making it easier to configure the environment.
+- **Additional Configuration Options**: The tool now supports additional configuration options, allowing for more flexibility in setting up the environment.
+- **Bug Fixes and Performance Enhancements**: Various bugs have been fixed, and performance improvements have been made to ensure a smoother setup process.
diff --git a/docs/development/setup.mdx b/docs/development/setup.mdx
index c08b13b8e4..7a99f2d9a3 100644
--- a/docs/development/setup.mdx
+++ b/docs/development/setup.mdx
@@ -85,3 +85,51 @@ Despite the detailed instructions, you might still encounter some hiccups along
If the issues persist, we invite you to submit an [issue on GitHub](https://github.com/reworkd/AgentGPT/issues). By doing so, you'll not only get help, but also assist us in identifying any problematic areas to improve on. Alternatively, you can reach out to our dedicated team on [Discord](https://discord.gg/jdSBAnmdnY). We're a community of learners and enthusiasts, and we're always ready to lend a hand.
Happy hacking and enjoy your journey with AgentGPT!
+
+## Recent Code Changes and Their Impact on the Setup Process
+
+The recent code changes have introduced several improvements and new features that impact the setup process. These changes ensure a smoother and more efficient setup experience. Here are the key updates:
+
+1. **Improved Error Handling**: The setup scripts have been updated to include better error handling, making it easier to identify and resolve issues during the setup process.
+
+2. **Enhanced Environment Configuration**: The environment configuration process has been streamlined, allowing for easier setup of environment variables and API keys.
+
+3. **Updated Dependencies**: The project dependencies have been updated to their latest versions, ensuring compatibility and improved performance.
+
+4. **Optimized Docker Configuration**: The Docker configuration has been optimized to reduce build times and improve overall performance.
+
+## Updated Setup Instructions
+
+To set up AgentGPT with the recent changes, follow these updated instructions:
+
+1. **Clone the Repository and Navigate into the Directory** - Once your terminal is open, you can clone the repository and move into the directory by running the commands below.
+
+ **For Mac/Linux users**
+
+ ```bash
+ git clone https://github.com/reworkd/AgentGPT.git
+ cd AgentGPT
+ ./setup.sh
+ ```
+
+ **For Windows users**
+
+ ```bash
+ git clone https://github.com/reworkd/AgentGPT.git
+ cd AgentGPT
+ ./setup.bat
+ ```
+
+2. **Follow the setup instructions from the script** - add the appropriate API keys, and once all of the services are running, navigate to [http://localhost:3000](http://localhost:3000) in your web browser.
+
+## New Features and Improvements
+
+The recent updates include the following new features and improvements:
+
+1. **Support for Additional API Keys**: The setup process now includes support for additional API keys, such as the Serper API Key and Replicate API Token, allowing for enhanced functionality.
+
+2. **Improved Documentation**: The setup documentation has been updated to provide clearer instructions and additional troubleshooting tips.
+
+3. **Optimized Performance**: The setup process has been optimized to reduce the time required for installation and configuration.
+
+4. **Enhanced User Experience**: The user interface has been improved to provide a more intuitive and user-friendly experience during the setup process.
diff --git a/next/src/env/schema.mjs b/next/src/env/schema.mjs
index 3f71e16f4a..dc7db3b274 100644
--- a/next/src/env/schema.mjs
+++ b/next/src/env/schema.mjs
@@ -1,4 +1,3 @@
-// @ts-check
import {z} from "zod";
const requiredForProduction = () =>
@@ -25,7 +24,6 @@ export const serverSchema = z.object({
// VERCEL_URL doesn't include `https` so it cant be validated as a URL
process.env.VERCEL ? z.string() : z.string().url()
),
- OPENAI_API_KEY: z.string().min(1).trim().optional(),
GOOGLE_CLIENT_ID: z.string().min(1).trim().optional(),
GOOGLE_CLIENT_SECRET: z.string().min(1).trim().optional(),
@@ -33,6 +31,8 @@ export const serverSchema = z.object({
GITHUB_CLIENT_SECRET: z.string().min(1).trim().optional(),
DISCORD_CLIENT_ID: z.string().min(1).trim().optional(),
DISCORD_CLIENT_SECRET: z.string().min(1).trim().optional(),
+ OPENAI_API_KEY: z.string().min(1).trim().optional(),
+ OLLAMA_API_KEY: z.string().min(1).trim().optional(),
});
/**
@@ -45,7 +45,6 @@ export const serverEnv = {
NODE_ENV: process.env.NODE_ENV,
NEXTAUTH_SECRET: process.env.NEXTAUTH_SECRET,
NEXTAUTH_URL: process.env.NEXTAUTH_URL,
- OPENAI_API_KEY: process.env.OPENAI_API_KEY,
GOOGLE_CLIENT_ID: process.env.GOOGLE_CLIENT_ID,
GOOGLE_CLIENT_SECRET: process.env.GOOGLE_CLIENT_SECRET,
@@ -53,6 +52,8 @@ export const serverEnv = {
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
DISCORD_CLIENT_ID: process.env.DISCORD_CLIENT_ID,
DISCORD_CLIENT_SECRET: process.env.DISCORD_CLIENT_SECRET,
+ OPENAI_API_KEY: process.env.OPENAI_API_KEY,
+ OLLAMA_API_KEY: process.env.OLLAMA_API_KEY,
};
/**
diff --git a/next/src/server/api/routers/agentRouter.ts b/next/src/server/api/routers/agentRouter.ts
index 05ab7cab75..9ae0186528 100644
--- a/next/src/server/api/routers/agentRouter.ts
+++ b/next/src/server/api/routers/agentRouter.ts
@@ -42,7 +42,7 @@ async function generateAgentName(goal: string) {
`,
},
],
- model: "gpt-3.5-turbo",
+ model: "llama3.2",
});
// @ts-ignore
diff --git a/next/src/stores/modelSettingsStore.ts b/next/src/stores/modelSettingsStore.ts
index c7baa7df52..086cfa2821 100644
--- a/next/src/stores/modelSettingsStore.ts
+++ b/next/src/stores/modelSettingsStore.ts
@@ -42,7 +42,7 @@ export const useModelSettingsStore = createSelectors(
partialize: (state) => ({
modelSettings: {
...state.modelSettings,
- customModelName: "gpt-3.5-turbo",
+ customModelName: "llama3.2",
maxTokens: Math.min(state.modelSettings.maxTokens, 4000),
},
}),
diff --git a/next/src/types/modelSettings.ts b/next/src/types/modelSettings.ts
index a3df4c07e8..b12532b29f 100644
--- a/next/src/types/modelSettings.ts
+++ b/next/src/types/modelSettings.ts
@@ -1,17 +1,19 @@
import { type Language } from "../utils/languages";
-export const [GPT_35_TURBO, GPT_35_TURBO_16K, GPT_4] = [
+export const [GPT_35_TURBO, GPT_35_TURBO_16K, GPT_4, LLAMA_3_2] = [
"gpt-3.5-turbo" as const,
"gpt-3.5-turbo-16k" as const,
"gpt-4" as const,
+ "llama3.2" as const,
];
-export const GPT_MODEL_NAMES = [GPT_35_TURBO, GPT_35_TURBO_16K, GPT_4];
-export type GPTModelNames = "gpt-3.5-turbo" | "gpt-3.5-turbo-16k" | "gpt-4";
+export const GPT_MODEL_NAMES = [GPT_35_TURBO, GPT_35_TURBO_16K, GPT_4, LLAMA_3_2];
+export type GPTModelNames = "gpt-3.5-turbo" | "gpt-3.5-turbo-16k" | "gpt-4" | "llama3.2";
export const MAX_TOKENS: Record<GPTModelNames, number> = {
"gpt-3.5-turbo": 4000,
"gpt-3.5-turbo-16k": 16000,
"gpt-4": 4000,
+ "llama3.2": 4000,
};
export interface ModelSettings {
diff --git a/platform/Dockerfile b/platform/Dockerfile
index fbdf5fa1ab..fb575e0892 100644
--- a/platform/Dockerfile
+++ b/platform/Dockerfile
@@ -1,4 +1,4 @@
-FROM python:3.11-slim-buster as prod
+FROM python:3.10-slim-buster as prod
RUN apt-get update && apt-get install -y \
default-libmysqlclient-dev \
@@ -30,6 +30,9 @@ RUN apt-get purge -y \
COPY . /app/src/
RUN poetry install --only main
+# Install ollama
+RUN pip install ollama
+
CMD ["/usr/local/bin/python", "-m", "reworkd_platform"]
FROM prod as dev
diff --git a/platform/README.md b/platform/README.md
index db545ea950..f7af56b49e 100644
--- a/platform/README.md
+++ b/platform/README.md
@@ -149,3 +149,354 @@ poetry run pytest -vv --cov="reworkd_platform" .
poetry self add poetry-plugin-up
poetry up --latest
```
+
+## Installing the package using pip
+
+To install the `reworkd_platform` package using pip, run the following command:
+
+```bash
+pip install reworkd_platform
+```
+
+## Using the package in any code
+
+To use the `reworkd_platform` package in your code, you can import it as follows:
+
+```python
+import reworkd_platform
+
+# Example usage
+reworkd_platform.some_function()
+```
+
+## Using pip functions
+
+The `reworkd_platform` package provides several functions for interacting with agents. Here are some examples:
+
+### Starting a goal agent
+
+```python
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+
+# Initialize the OpenAIAgentService
+model = WrappedChatOpenAI(model_name="gpt-3.5-turbo")
+settings = ModelSettings(language="en")
+token_service = TokenService.create()
+callbacks = None
+user = UserBase(id=1, name="John Doe")
+oauth_crud = await OAuthCrud.inject()  # requires an async context
+
+agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+# Start a goal agent (the pip_ helpers are coroutines, so await them)
+tasks = await agent_service.pip_start_goal_agent(goal="Your goal here")
+print(tasks)
+```
+
+### Analyzing a task agent
+
+```python
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+
+# Initialize the OpenAIAgentService
+model = WrappedChatOpenAI(model_name="gpt-3.5-turbo")
+settings = ModelSettings(language="en")
+token_service = TokenService.create()
+callbacks = None
+user = UserBase(id=1, name="John Doe")
+oauth_crud = await OAuthCrud.inject()  # requires an async context
+
+agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+# Analyze a task agent
+analysis = await agent_service.pip_analyze_task_agent(goal="Your goal here", task="Your task here", tool_names=["tool1", "tool2"])
+print(analysis)
+```
+
+### Executing a task agent
+
+```python
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+
+# Initialize the OpenAIAgentService
+model = WrappedChatOpenAI(model_name="gpt-3.5-turbo")
+settings = ModelSettings(language="en")
+token_service = TokenService.create()
+callbacks = None
+user = UserBase(id=1, name="John Doe")
+oauth_crud = await OAuthCrud.inject()  # requires an async context
+
+agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+# Execute a task agent, using the analysis produced by pip_analyze_task_agent
+analysis = await agent_service.pip_analyze_task_agent(goal="Your goal here", task="Your task here", tool_names=["tool1", "tool2"])
+response = await agent_service.pip_execute_task_agent(goal="Your goal here", task="Your task here", analysis=analysis)
+print(response)
+```
+
+### Creating tasks agent
+
+```python
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+
+# Initialize the OpenAIAgentService
+model = WrappedChatOpenAI(model_name="gpt-3.5-turbo")
+settings = ModelSettings(language="en")
+token_service = TokenService.create()
+callbacks = None
+user = UserBase(id=1, name="John Doe")
+oauth_crud = await OAuthCrud.inject()  # requires an async context
+
+agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+# Create tasks agent
+tasks = await agent_service.pip_create_tasks_agent(goal="Your goal here", tasks=["task1", "task2"], last_task="Your last task here", result="Your result here")
+print(tasks)
+```
+
+### Summarizing task agent
+
+```python
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+
+# Initialize the OpenAIAgentService
+model = WrappedChatOpenAI(model_name="gpt-3.5-turbo")
+settings = ModelSettings(language="en")
+token_service = TokenService.create()
+callbacks = None
+user = UserBase(id=1, name="John Doe")
+oauth_crud = await OAuthCrud.inject()  # requires an async context
+
+agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+# Summarize task agent
+response = await agent_service.pip_summarize_task_agent(goal="Your goal here", results=["result1", "result2"])
+print(response)
+```
+
+### Chatting with agent
+
+```python
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+
+# Initialize the OpenAIAgentService
+model = WrappedChatOpenAI(model_name="gpt-3.5-turbo")
+settings = ModelSettings(language="en")
+token_service = TokenService.create()
+callbacks = None
+user = UserBase(id=1, name="John Doe")
+oauth_crud = await OAuthCrud.inject()  # requires an async context
+
+agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+# Chat with agent
+response = await agent_service.pip_chat(message="Your message here", results=["result1", "result2"])
+print(response)
+```
+
+## Using ollama
+
+The `reworkd_platform` package also provides support for `ollama`. Here are some examples:
+
+### Adding ollama as a dependency
+
+To add `ollama` as a dependency, include it in your `pyproject.toml` file under `[tool.poetry.dependencies]`:
+
+```toml
+[tool.poetry.dependencies]
+ollama = "^0.1.0"
+```
+
+### Installing ollama in Docker
+
+To install `ollama` in the Docker image, add the following command to your `Dockerfile`:
+
+```dockerfile
+# Install ollama
+RUN pip install ollama
+```
+
+### Using ollama in your code
+
+To use `ollama` in your code, you can import it as follows:
+
+```python
+import ollama
+
+# Example usage: the package exposes a module-level chat helper
+response = ollama.chat(
+    model="llama3.2",
+    messages=[{"role": "user", "content": "Your prompt here"}],
+)
+print(response["message"]["content"])
+```
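+
+The platform services use the asynchronous client. A minimal async sketch (assuming a local Ollama server on its default port, 11434) that streams a reply:
+
+```python
+import asyncio
+from ollama import AsyncClient
+
+async def main():
+    client = AsyncClient(host="http://localhost:11434")
+    # With stream=True, the awaited call returns an async iterator of chunks.
+    stream = await client.chat(
+        model="llama3.2",
+        messages=[{"role": "user", "content": "Your prompt here"}],
+        stream=True,
+    )
+    async for chunk in stream:
+        print(chunk["message"]["content"], end="", flush=True)
+
+asyncio.run(main())
+```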
+
+## Using Python 3.10
+
+The `reworkd_platform` package is compatible with Python 3.10. Here are some examples:
+
+### Specifying Python 3.10 in `pyproject.toml`
+
+To specify Python 3.10 as the required version, include the following in your `pyproject.toml` file:
+
+```toml
+[tool.poetry.dependencies]
+python = "^3.10"
+```
+
+### Using Python 3.10 in Docker
+
+To use Python 3.10 in the Docker image, update the base image in your `Dockerfile`:
+
+```dockerfile
+FROM python:3.10-slim-buster as prod
+```
+
+### Running the project with Python 3.10
+
+To run the project with Python 3.10, make sure you have Python 3.10 installed on your system. You can download and install Python 3.10 from the official Python website: https://www.python.org/downloads/release/python-3100/
+
+Once you have Python 3.10 installed, you can create a virtual environment and install the dependencies using Poetry:
+
+```bash
+python3.10 -m venv venv
+source venv/bin/activate
+poetry install
+poetry run python -m reworkd_platform
+```
+
+This will start the server on the configured host using Python 3.10.
+
+## Recent Updates
+
+### Project Structure and Configuration
+
+The project structure and configuration have been updated to improve maintainability and scalability. The following changes have been made:
+
+- Refactored the project structure to follow best practices and improve code organization.
+- Updated the configuration files to support new features and enhancements.
+- Added support for environment-specific configurations.
+
+### Instructions for Running the Project
+
+To run the project with the recent changes, follow these updated instructions:
+
+1. Clone the repository:
+
+```bash
+git clone https://github.com/reworkd/AgentGPT.git
+cd AgentGPT
+```
+
+2. Create a virtual environment and activate it:
+
+```bash
+python3.10 -m venv venv
+source venv/bin/activate
+```
+
+3. Install the dependencies using Poetry:
+
+```bash
+poetry install
+```
+
+4. Start the server:
+
+```bash
+poetry run python -m reworkd_platform
+```
+
+5. Access the Swagger documentation at `/api/docs`.
+
+### New Features and Improvements
+
+The recent updates include the following new features and improvements:
+
+- Added support for the `ollama` package, allowing integration with the `llama3.2` model.
+- Improved the handling of environment variables and configuration settings.
+- Enhanced the project structure to follow best practices and improve code organization.
+- Updated the Docker configuration to support the installation of `ollama` and other dependencies.
+- Added new `pip_`-prefixed functions for interacting with agents: starting a goal agent, analyzing a task, executing a task, creating follow-up tasks, summarizing results, and chatting with an agent.
+
+## Example of Initializing the OpenAIAgentService, Setting Up the Environment, and Running the Main Function
+
+```python
+import asyncio
+import os
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+from fastapi.responses import StreamingResponse
+
+
+async def get_oauth_crud():
+ oauth_crud = await OAuthCrud.inject()
+ return oauth_crud
+
+
+async def main():
+ # Ensure the OPENAI_API_KEY is set
+ openai_api_key = os.getenv("OPENAI_API_KEY")
+ if not openai_api_key:
+ raise ValueError("The environment variable OPENAI_API_KEY is not set.")
+
+ # Initialize the OpenAIAgentService
+ model = WrappedChatOpenAI(model_name="llama3.2", openai_api_key=openai_api_key)
+ settings = ModelSettings(language="en")
+ token_service = TokenService.create()
+ callbacks = None
+ user = UserBase(id=1, name="John Doe")
+ oauth_crud = await get_oauth_crud()
+
+ agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+ # Chat with agent
+ response = await agent_service.pip_chat(message="Your message here", results=["result1", "result2"])
+
+ if isinstance(response, StreamingResponse):
+ response_content = []
+ async for chunk in response.body_iterator:
+ if isinstance(chunk, bytes):
+ response_content.append(chunk.decode('utf-8'))
+ else:
+ response_content.append(chunk)
+ print(''.join(response_content))
+ else:
+ print(response)
+
+
+# Run the main function
+asyncio.run(main())
+```
diff --git a/platform/pyproject.toml b/platform/pyproject.toml
index e48072e813..1654c7b84c 100644
--- a/platform/pyproject.toml
+++ b/platform/pyproject.toml
@@ -14,7 +14,7 @@ maintainers = [
readme = "README.md"
[tool.poetry.dependencies]
-python = "^3.11"
+python = "^3.10"
fastapi = "^0.98.0"
boto3 = "^1.28.51"
uvicorn = { version = "^0.22.0", extras = ["standard"] }
@@ -41,6 +41,7 @@ botocore = "^1.31.51"
stripe = "^5.5.0"
cryptography = "^41.0.4"
httpx = "^0.25.0"
+ollama = "^0.1.0"
[tool.poetry.dev-dependencies]
@@ -96,5 +97,5 @@ env = [
]
[build-system]
-requires = ["poetry-core>=1.0.0"]
-build-backend = "poetry.core.masonry.api"
+requires = ["poetry-core>=1.0.0", "setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
diff --git a/platform/reworkd_platform/db/crud/oauth.py b/platform/reworkd_platform/db/crud/oauth.py
index 911bd01207..8e263c5e08 100644
--- a/platform/reworkd_platform/db/crud/oauth.py
+++ b/platform/reworkd_platform/db/crud/oauth.py
@@ -12,6 +12,9 @@
class OAuthCrud(BaseCrud):
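+    # Explicit constructor so the CRUD can be built directly with an AsyncSession, alongside the async inject() factory.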
+ def __init__(self, session: AsyncSession):
+ super().__init__(session)
+
@classmethod
async def inject(
cls,
diff --git a/platform/reworkd_platform/db/crud/user.py b/platform/reworkd_platform/db/crud/user.py
index 2e9f7111ce..a9db35a20a 100644
--- a/platform/reworkd_platform/db/crud/user.py
+++ b/platform/reworkd_platform/db/crud/user.py
@@ -6,9 +6,13 @@
from reworkd_platform.db.crud.base import BaseCrud
from reworkd_platform.db.models.auth import OrganizationUser
from reworkd_platform.db.models.user import UserSession
+from sqlalchemy.ext.asyncio import AsyncSession
class UserCrud(BaseCrud):
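+    # Mirror OAuthCrud: accept the session explicitly so the CRUD can be constructed outside FastAPI injection.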
+ def __init__(self, session: AsyncSession):
+ super().__init__(session)
+
async def get_user_session(self, token: str) -> UserSession:
query = (
select(UserSession)
diff --git a/platform/reworkd_platform/schemas/agent.py b/platform/reworkd_platform/schemas/agent.py
index f6c5e6e732..86034c00ce 100644
--- a/platform/reworkd_platform/schemas/agent.py
+++ b/platform/reworkd_platform/schemas/agent.py
@@ -9,6 +9,7 @@
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
+ "llama3.2",
]
Loop_Step = Literal[
"start",
@@ -22,6 +23,7 @@
"gpt-3.5-turbo": 4000,
"gpt-3.5-turbo-16k": 16000,
"gpt-4": 8000,
+ "llama3.2": 4000,
}
diff --git a/platform/reworkd_platform/services/pinecone/pinecone.py b/platform/reworkd_platform/services/pinecone/pinecone.py
index db64ad9a56..102e30f072 100644
--- a/platform/reworkd_platform/services/pinecone/pinecone.py
+++ b/platform/reworkd_platform/services/pinecone/pinecone.py
@@ -40,7 +40,6 @@ def __init__(self, index_name: str, namespace: str = ""):
def __enter__(self) -> AgentMemory:
self.embeddings: Embeddings = OpenAIEmbeddings(
client=None, # Meta private value but mypy will complain its missing
- openai_api_key=settings.openai_api_key,
)
return self
diff --git a/platform/reworkd_platform/services/tokenizer/token_service.py b/platform/reworkd_platform/services/tokenizer/token_service.py
index b457c5f0a1..efc769d42a 100644
--- a/platform/reworkd_platform/services/tokenizer/token_service.py
+++ b/platform/reworkd_platform/services/tokenizer/token_service.py
@@ -2,32 +2,52 @@
from reworkd_platform.schemas.agent import LLM_MODEL_MAX_TOKENS, LLM_Model
from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.logging import logger
class TokenService:
def __init__(self, encoding: Encoding):
self.encoding = encoding
+ logger.info("TokenService initialized with encoding: {}", encoding.name)
@classmethod
def create(cls, encoding: str = "cl100k_base") -> "TokenService":
+ logger.info("Creating TokenService with encoding: {}", encoding)
return cls(get_encoding(encoding))
def tokenize(self, text: str) -> list[int]:
- return self.encoding.encode(text)
+ logger.debug("Tokenizing text: {}", text)
+ tokens = self.encoding.encode(text)
+ logger.debug("Tokenized text to tokens: {}", tokens)
+ return tokens
def detokenize(self, tokens: list[int]) -> str:
- return self.encoding.decode(tokens)
+ logger.debug("Detokenizing tokens: {}", tokens)
+ text = self.encoding.decode(tokens)
+ logger.debug("Detokenized tokens to text: {}", text)
+ return text
def count(self, text: str) -> int:
- return len(self.tokenize(text))
+ logger.debug("Counting tokens in text: {}", text)
+ count = len(self.tokenize(text))
+ logger.debug("Counted {} tokens in text", count)
+ return count
def get_completion_space(self, model: LLM_Model, *prompts: str) -> int:
+ logger.info("Calculating completion space for model: {} with prompts: {}", model, prompts)
max_allowed_tokens = LLM_MODEL_MAX_TOKENS.get(model, 4000)
prompt_tokens = sum([self.count(p) for p in prompts])
- return max_allowed_tokens - prompt_tokens
+ completion_space = max_allowed_tokens - prompt_tokens
+ logger.info("Calculated completion space: {}", completion_space)
+ return completion_space
def calculate_max_tokens(self, model: WrappedChatOpenAI, *prompts: str) -> None:
+ logger.info("Calculating max tokens for model: {} with prompts: {}", model.model_name, prompts)
requested_tokens = self.get_completion_space(model.model_name, *prompts)
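+        # A model created without an explicit cap receives the full remaining completion space.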
+ if model.max_tokens is None:
+ model.max_tokens = requested_tokens
+
model.max_tokens = min(model.max_tokens, requested_tokens)
model.max_tokens = max(model.max_tokens, 1)
+ logger.info("Calculated max tokens for model: {}", model.max_tokens)
diff --git a/platform/reworkd_platform/settings.py b/platform/reworkd_platform/settings.py
index fe3563d82b..8dd70f87ab 100644
--- a/platform/reworkd_platform/settings.py
+++ b/platform/reworkd_platform/settings.py
@@ -53,7 +53,6 @@ class Settings(BaseSettings):
# OpenAI
openai_api_base: str = "https://api.openai.com/v1"
- openai_api_key: str = ""
openai_api_version: str = "2023-08-01-preview"
azure_openai_deployment_name: str = ""
@@ -64,6 +63,10 @@ class Settings(BaseSettings):
replicate_api_key: Optional[str] = None
serp_api_key: Optional[str] = None
+ # Ollama
+ ollama_api_base: str = "https://api.ollama.com/v1"
+ ollama_api_key: Optional[str] = None
+
# Frontend URL for CORS
frontend_url: str = "http://localhost:3000"
allowed_origins_regex: Optional[str] = None
diff --git a/platform/reworkd_platform/tests/agent/test_model_factory.py b/platform/reworkd_platform/tests/agent/test_model_factory.py
index f6579897a4..b937fbb40b 100644
--- a/platform/reworkd_platform/tests/agent/test_model_factory.py
+++ b/platform/reworkd_platform/tests/agent/test_model_factory.py
@@ -136,3 +136,28 @@ def test_custom_model_settings(model_settings: ModelSettings, streaming: bool):
assert model.model_name.startswith(model_settings.model)
assert model.max_tokens == model_settings.max_tokens
assert model.streaming == streaming
+
+
+def test_create_model_without_max_tokens():
+ user = UserBase(id="user_id")
+ settings = Settings()
+ model_settings = ModelSettings(
+ temperature=0.7,
+ model="gpt-3.5-turbo",
+ )
+
+    settings.openai_api_base = "https://api.openai.com"
+    settings.openai_api_version = "version"
+
+ result = create_model(settings, model_settings, user, streaming=False)
+ assert issubclass(result.__class__, WrappedChatOpenAI)
+ assert issubclass(result.__class__, ChatOpenAI)
+
+ # Check if the required keys are set properly
+    assert result.openai_api_base == settings.openai_api_base
+ assert result.temperature == model_settings.temperature
+ assert result.max_tokens is None
+ assert result.streaming is False
+ assert result.max_retries == 5
diff --git a/platform/reworkd_platform/tests/test_token_service.py b/platform/reworkd_platform/tests/test_token_service.py
index 8609c0730e..e55b021b33 100644
--- a/platform/reworkd_platform/tests/test_token_service.py
+++ b/platform/reworkd_platform/tests/test_token_service.py
@@ -41,6 +41,17 @@ def test_calculate_max_tokens_with_small_max_tokens() -> None:
assert model.max_tokens == initial_max_tokens
+def test_calculate_max_tokens_with_none_max_tokens() -> None:
+ service = TokenService(encoding)
+ model = Mock(spec=["model_name", "max_tokens"])
+ model.model_name = "gpt-3.5-turbo"
+ model.max_tokens = None
+
+ service.calculate_max_tokens(model, "Hello")
+
+    # The prompt's own tokens come out of the model's total budget.
+    assert model.max_tokens == LLM_MODEL_MAX_TOKENS.get("gpt-3.5-turbo") - service.count("Hello")
+
+
def test_calculate_max_tokens_with_high_completion_tokens() -> None:
service = TokenService(encoding)
prompt_tokens = service.count(LONG_TEXT)
@@ -100,6 +111,4 @@ def test_calculate_max_tokens_with_negative_result() -> None:
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
-This is some long text. This is some long text. This is some long text.
-This is some long text. This is some long text. This is some long text.
"""
diff --git a/platform/reworkd_platform/web/api/agent/agent_service/ollama_agent_service.py b/platform/reworkd_platform/web/api/agent/agent_service/ollama_agent_service.py
new file mode 100644
index 0000000000..3b59a973fa
--- /dev/null
+++ b/platform/reworkd_platform/web/api/agent/agent_service/ollama_agent_service.py
@@ -0,0 +1,293 @@
+from typing import List, Optional
+
+from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
+from lanarky.responses import StreamingResponse
+from langchain.callbacks.base import AsyncCallbackHandler
+from langchain.output_parsers import PydanticOutputParser
+from langchain.prompts.chat import (
+ ChatPromptTemplate,
+ SystemMessagePromptTemplate,
+ HumanMessagePromptTemplate
+)
+from loguru import logger
+from pydantic import ValidationError
+
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.schemas.user import UserBase
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
+from reworkd_platform.web.api.agent.analysis import Analysis, AnalysisArguments
+from reworkd_platform.web.api.agent.helpers import (
+ call_model_with_handling,
+ parse_with_handling,
+)
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.web.api.agent.prompts import (
+ analyze_task_prompt,
+ chat_prompt,
+ create_tasks_prompt,
+ start_goal_prompt,
+)
+from reworkd_platform.web.api.agent.task_output_parser import TaskOutputParser
+from reworkd_platform.web.api.agent.tools.open_ai_function import get_tool_function
+from reworkd_platform.web.api.agent.tools.tools import (
+ get_default_tool,
+ get_tool_from_name,
+ get_tool_name,
+ get_user_tools,
+)
+from reworkd_platform.web.api.agent.tools.utils import summarize
+from ollama import AsyncClient
+
+
+class OllamaAgentService(AgentService):
+ def __init__(
+ self,
+ model: WrappedChatOpenAI,
+ settings: ModelSettings,
+ token_service: TokenService,
+ callbacks: Optional[List[AsyncCallbackHandler]],
+ user: UserBase,
+ oauth_crud: OAuthCrud,
+ ):
+ self.model = model
+ self.settings = settings
+ self.token_service = token_service
+ self.callbacks = callbacks
+ self.user = user
+ self.oauth_crud = oauth_crud
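+        # Async client for a local Ollama server; adjust the host for remote deployments.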
+ self.client = AsyncClient(host='http://localhost:11434')
+
+ async def start_goal_agent(self, *, goal: str) -> List[str]:
+ prompt = ChatPromptTemplate.from_messages(
+ [SystemMessagePromptTemplate(prompt=start_goal_prompt)]
+ )
+
+ self.token_service.calculate_max_tokens(
+ self.model,
+ prompt.format_prompt(
+ goal=goal,
+ language=self.settings.language,
+ ).to_string(),
+ )
+
+        completion = await call_model_with_handling(
+            self.model,
+            prompt,
+            {"goal": goal, "language": self.settings.language},
+            settings=self.settings,
+            callbacks=self.callbacks,
+        )
+
+ task_output_parser = TaskOutputParser(completed_tasks=[])
+ tasks = parse_with_handling(task_output_parser, completion)
+
+ return tasks
+
+ async def analyze_task_agent(
+ self, *, goal: str, task: str, tool_names: List[str]
+ ) -> Analysis:
+ user_tools = await get_user_tools(tool_names, self.user, self.oauth_crud)
+ functions = list(map(get_tool_function, user_tools))
+ prompt = analyze_task_prompt.format_prompt(
+ goal=goal,
+ task=task,
+ language=self.settings.language,
+ )
+
+ self.token_service.calculate_max_tokens(
+ self.model,
+ prompt.to_string(),
+ str(functions),
+ )
+
+        # Ollama's chat API takes plain dict messages and has no OpenAI-style
+        # function calling, so pass the tool definitions as context and parse
+        # the arguments from the completion text instead.
+        message = await self.client.chat(
+            model="llama3.2",
+            messages=[
+                {"role": "user", "content": prompt.to_string()},
+                {"role": "user", "content": f"Available functions: {functions}"},
+            ],
+        )
+
+        completion = message["message"]["content"]
+
+ try:
+ pydantic_parser = PydanticOutputParser(pydantic_object=AnalysisArguments)
+ analysis_arguments = parse_with_handling(pydantic_parser, completion)
+ return Analysis(
+                action=get_tool_name(get_default_tool()),
+ **analysis_arguments.dict(),
+ )
+        except ValidationError:
+ return Analysis.get_default_analysis(task)
+
+ async def execute_task_agent(
+ self,
+ *,
+ goal: str,
+ task: str,
+ analysis: Analysis,
+ ) -> StreamingResponse:
+ if self.model.max_tokens and self.model.max_tokens > 3000:
+ self.model.max_tokens = max(self.model.max_tokens - 1000, 3000)
+
+ tool_class = get_tool_from_name(analysis.action)
+ return await tool_class(self.model, self.settings.language).call(
+ goal,
+ task,
+ analysis.arg,
+ self.user,
+ self.oauth_crud,
+ )
+
+ async def create_tasks_agent(
+ self,
+ *,
+ goal: str,
+ tasks: List[str],
+ last_task: str,
+ result: str,
+ completed_tasks: Optional[List[str]] = None,
+ ) -> List[str]:
+ prompt = ChatPromptTemplate.from_messages(
+ [SystemMessagePromptTemplate(prompt=create_tasks_prompt)]
+ )
+
+ args = {
+ "goal": goal,
+ "language": self.settings.language,
+ "tasks": "\n".join(tasks),
+ "lastTask": last_task,
+ "result": result,
+ }
+
+ self.token_service.calculate_max_tokens(
+ self.model, prompt.format_prompt(**args).to_string()
+ )
+
+ completion = await call_model_with_handling(
+ self.model, prompt, args, settings=self.settings, callbacks=self.callbacks
+ )
+
+ previous_tasks = (completed_tasks or []) + tasks
+ return [completion] if completion not in previous_tasks else []
+
+ async def summarize_task_agent(
+ self,
+ *,
+ goal: str,
+ results: List[str],
+ ) -> FastAPIStreamingResponse:
+        self.model.model_name = "llama3.2"
+        self.model.max_tokens = 4000  # llama3.2 is capped at 4000 tokens in LLM_MODEL_MAX_TOKENS
+
+        snippet_max_tokens = 3000  # leave room for the rest of the prompt
+ text_tokens = self.token_service.tokenize("".join(results))
+ text = self.token_service.detokenize(text_tokens[0:snippet_max_tokens])
+ logger.info(f"Summarizing text: {text}")
+
+ return await summarize(
+ client=self.client,
+ language=self.settings.language,
+ goal=goal,
+ text=text,
+ )
+
+ async def chat(
+ self,
+ *,
+ message: str,
+ results: List[str],
+ ) -> FastAPIStreamingResponse:
+ self.model.model_name = "llama3.2"
+ prompt = ChatPromptTemplate.from_messages(
+ [
+ SystemMessagePromptTemplate(prompt=chat_prompt),
+ *[HumanMessagePromptTemplate.from_template(result) for result in results],
+ HumanMessagePromptTemplate.from_template(message),
+ ]
+ )
+
+ self.token_service.calculate_max_tokens(
+ self.model,
+ prompt.format_prompt(
+ language=self.settings.language,
+ ).to_string(),
+ )
+
+ formatted_prompt = prompt.format_prompt(language=self.settings.language)
+        # LangChain messages expose `type` ("system"/"human"/"ai"), not `role`;
+        # map them onto the role names Ollama expects.
+        role_map = {"system": "system", "human": "user", "ai": "assistant"}
+        messages = [
+            {"role": role_map.get(msg.type, "user"), "content": msg.content}
+            for msg in formatted_prompt.to_messages()
+        ]
+
+ try:
+ response = await self.client.chat(
+ model="llama3.2",
+ messages=messages,
+ stream=True,
+ )
+ except Exception as e:
+ logger.exception("Error during Ollama chat request.")
+ raise
+
+ async def stream_response():
+ async for chunk in response:
+ if 'message' in chunk and 'content' in chunk['message']:
+ yield chunk['message']['content']
+
+ return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream")
+
+ async def pip_start_goal_agent(self, *, goal: str) -> List[str]:
+ return await self.start_goal_agent(goal=goal)
+
+ async def pip_analyze_task_agent(
+ self, *, goal: str, task: str, tool_names: List[str]
+ ) -> Analysis:
+ return await self.analyze_task_agent(goal=goal, task=task, tool_names=tool_names)
+
+ async def pip_execute_task_agent(
+ self,
+ *,
+ goal: str,
+ task: str,
+ analysis: Analysis,
+ ) -> StreamingResponse:
+ return await self.execute_task_agent(goal=goal, task=task, analysis=analysis)
+
+ async def pip_create_tasks_agent(
+ self,
+ *,
+ goal: str,
+ tasks: List[str],
+ last_task: str,
+ result: str,
+ completed_tasks: Optional[List[str]] = None,
+ ) -> List[str]:
+ return await self.create_tasks_agent(
+ goal=goal,
+ tasks=tasks,
+ last_task=last_task,
+ result=result,
+ completed_tasks=completed_tasks,
+ )
+
+ async def pip_summarize_task_agent(
+ self,
+ *,
+ goal: str,
+ results: List[str],
+ ) -> FastAPIStreamingResponse:
+ return await self.summarize_task_agent(goal=goal, results=results)
+
+ async def pip_chat(
+ self,
+ *,
+ message: str,
+ results: List[str],
+ ) -> FastAPIStreamingResponse:
+ return await self.chat(message=message, results=results)
diff --git a/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py b/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
index 3221a3fbfc..55b6604751 100644
--- a/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
+++ b/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
@@ -2,11 +2,13 @@
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from lanarky.responses import StreamingResponse
-from langchain import LLMChain
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.output_parsers import PydanticOutputParser
-from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
-from langchain.schema import HumanMessage
+from langchain.prompts.chat import (
+ ChatPromptTemplate,
+ SystemMessagePromptTemplate,
+ HumanMessagePromptTemplate
+)
from loguru import logger
from pydantic import ValidationError
@@ -38,6 +40,7 @@
)
from reworkd_platform.web.api.agent.tools.utils import summarize
from reworkd_platform.web.api.errors import OpenAIError
+from ollama import AsyncClient # Updated import
class OpenAIAgentService(AgentService):
@@ -56,8 +59,12 @@ def __init__(
self.callbacks = callbacks
self.user = user
self.oauth_crud = oauth_crud
+ # Initialize the Async Ollama client once
+        self.client = AsyncClient(host='http://localhost:11434')  # Hardcoded for local use; read from settings/env for flexibility
+ logger.info("OpenAIAgentService initialized with model: {}, settings: {}, user: {}", model, settings, user)
async def start_goal_agent(self, *, goal: str) -> List[str]:
+ logger.info("Starting goal agent with goal: {}", goal)
prompt = ChatPromptTemplate.from_messages(
[SystemMessagePromptTemplate(prompt=start_goal_prompt)]
)
@@ -82,12 +89,13 @@ async def start_goal_agent(self, *, goal: str) -> List[str]:
task_output_parser = TaskOutputParser(completed_tasks=[])
tasks = parse_with_handling(task_output_parser, completion)
-
+ logger.info("Goal agent completed with tasks: {}", tasks)
return tasks
async def analyze_task_agent(
self, *, goal: str, task: str, tool_names: List[str]
) -> Analysis:
+ logger.info("Analyzing task with goal: {}, task: {}, tool_names: {}", goal, task, tool_names)
user_tools = await get_user_tools(tool_names, self.user, self.oauth_crud)
functions = list(map(get_tool_function, user_tools))
prompt = analyze_task_prompt.format_prompt(
@@ -116,11 +124,14 @@ async def analyze_task_agent(
try:
pydantic_parser = PydanticOutputParser(pydantic_object=AnalysisArguments)
analysis_arguments = parse_with_handling(pydantic_parser, completion)
- return Analysis(
+ analysis = Analysis(
action=function_call.get("name", get_tool_name(get_default_tool())),
**analysis_arguments.dict(),
)
- except (OpenAIError, ValidationError):
+ logger.info("Task analysis completed: {}", analysis)
+ return analysis
+ except (OpenAIError, ValidationError) as e:
+ logger.error("Error during task analysis: {}", e)
return Analysis.get_default_analysis(task)
async def execute_task_agent(
@@ -130,18 +141,21 @@ async def execute_task_agent(
task: str,
analysis: Analysis,
) -> StreamingResponse:
+ logger.info("Executing task with goal: {}, task: {}, analysis: {}", goal, task, analysis)
# TODO: More mature way of calculating max_tokens
- if self.model.max_tokens > 3000:
+ if self.model.max_tokens and self.model.max_tokens > 3000:
self.model.max_tokens = max(self.model.max_tokens - 1000, 3000)
tool_class = get_tool_from_name(analysis.action)
- return await tool_class(self.model, self.settings.language).call(
+ response = await tool_class(self.model, self.settings.language).call(
goal,
task,
analysis.arg,
self.user,
self.oauth_crud,
)
+ logger.info("Task execution completed with response: {}", response)
+ return response
async def create_tasks_agent(
self,
@@ -152,6 +166,7 @@ async def create_tasks_agent(
result: str,
completed_tasks: Optional[List[str]] = None,
) -> List[str]:
+ logger.info("Creating tasks with goal: {}, tasks: {}, last_task: {}, result: {}", goal, tasks, last_task, result)
prompt = ChatPromptTemplate.from_messages(
[SystemMessagePromptTemplate(prompt=create_tasks_prompt)]
)
@@ -173,7 +188,9 @@ async def create_tasks_agent(
)
previous_tasks = (completed_tasks or []) + tasks
- return [completion] if completion not in previous_tasks else []
+ new_tasks = [completion] if completion not in previous_tasks else []
+ logger.info("Tasks created: {}", new_tasks)
+ return new_tasks
async def summarize_task_agent(
self,
@@ -181,20 +198,23 @@ async def summarize_task_agent(
goal: str,
results: List[str],
) -> FastAPIStreamingResponse:
- self.model.model_name = "gpt-3.5-turbo-16k"
+ logger.info("Summarizing task with goal: {}, results: {}", goal, results)
+ self.model.model_name = "llama3.2"
self.model.max_tokens = 8000 # Total tokens = prompt tokens + completion tokens
snippet_max_tokens = 7000 # Leave room for the rest of the prompt
text_tokens = self.token_service.tokenize("".join(results))
text = self.token_service.detokenize(text_tokens[0:snippet_max_tokens])
- logger.info(f"Summarizing text: {text}")
+ logger.info("Summarizing text: {}", text)
- return summarize(
- model=self.model,
+ response = await summarize(
+ client=self.client, # Pass the initialized AsyncClient
language=self.settings.language,
goal=goal,
text=text,
)
+ logger.info("Task summary completed with response: {}", response)
+ return response
async def chat(
self,
@@ -202,15 +222,19 @@ async def chat(
message: str,
results: List[str],
) -> FastAPIStreamingResponse:
- self.model.model_name = "gpt-3.5-turbo-16k"
+ logger.info("Chatting with message: {}, results: {}", message, results)
+ self.model.model_name = "llama3.2"
prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate(prompt=chat_prompt),
- *[HumanMessage(content=result) for result in results],
- HumanMessage(content=message),
+ *[HumanMessagePromptTemplate.from_template(result) for result in results],
+ HumanMessagePromptTemplate.from_template(message),
]
)
self.token_service.calculate_max_tokens(
self.model,
prompt.format_prompt(
@@ -218,10 +242,83 @@ async def chat(
).to_string(),
)
- chain = LLMChain(llm=self.model, prompt=prompt)
+ # Format the prompt and extract messages
+ formatted_prompt = prompt.format_prompt(language=self.settings.language)
+        # LangChain messages expose `type` ("system"/"human"/"ai"), not `role`;
+        # map them onto the role names Ollama expects.
+        role_map = {"system": "system", "human": "user", "ai": "assistant"}
+        messages = [
+            {"role": role_map.get(msg.type, "user"), "content": msg.content}
+            for msg in formatted_prompt.to_messages()
+        ]
+
+ try:
+ # Make the chat request with streaming
+ response = await self.client.chat(
+ model="llama3.2",
+ messages=messages,
+ stream=True,
+ )
+ logger.info("Chat request successful with response: {}", response)
+ except Exception as e:
+ logger.exception("Error during Ollama chat request.")
+ # Handle specific exceptions if necessary
+ raise
- return StreamingResponse.from_chain(
- chain,
- {"language": self.settings.language},
- media_type="text/event-stream",
+ # Define an asynchronous generator to yield streamed responses
+ async def stream_response():
+ async for chunk in response:
+ # Ensure 'message' and 'content' keys exist
+ if 'message' in chunk and 'content' in chunk['message']:
+ yield chunk['message']['content']
+
+ return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream")
+
+    # Thin pip_-prefixed wrappers so the agent service can be driven directly from scripts.
+
+ async def pip_start_goal_agent(self, *, goal: str) -> List[str]:
+ return await self.start_goal_agent(goal=goal)
+
+ async def pip_analyze_task_agent(
+ self, *, goal: str, task: str, tool_names: List[str]
+ ) -> Analysis:
+ return await self.analyze_task_agent(goal=goal, task=task, tool_names=tool_names)
+
+ async def pip_execute_task_agent(
+ self,
+ *,
+ goal: str,
+ task: str,
+ analysis: Analysis,
+ ) -> StreamingResponse:
+ return await self.execute_task_agent(goal=goal, task=task, analysis=analysis)
+
+ async def pip_create_tasks_agent(
+ self,
+ *,
+ goal: str,
+ tasks: List[str],
+ last_task: str,
+ result: str,
+ completed_tasks: Optional[List[str]] = None,
+ ) -> List[str]:
+ return await self.create_tasks_agent(
+ goal=goal,
+ tasks=tasks,
+ last_task=last_task,
+ result=result,
+ completed_tasks=completed_tasks,
)
+
+ async def pip_summarize_task_agent(
+ self,
+ *,
+ goal: str,
+ results: List[str],
+ ) -> FastAPIStreamingResponse:
+ return await self.summarize_task_agent(goal=goal, results=results)
+
+ async def pip_chat(
+ self,
+ *,
+ message: str,
+ results: List[str],
+ ) -> FastAPIStreamingResponse:
+ return await self.chat(message=message, results=results)
diff --git a/platform/reworkd_platform/web/api/agent/model_factory.py b/platform/reworkd_platform/web/api/agent/model_factory.py
index 52644a9143..8565c0eb2f 100644
--- a/platform/reworkd_platform/web/api/agent/model_factory.py
+++ b/platform/reworkd_platform/web/api/agent/model_factory.py
@@ -1,11 +1,12 @@
from typing import Any, Dict, Optional, Tuple, Type, Union
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
-from pydantic import Field
+from pydantic import Field, ValidationError
from reworkd_platform.schemas.agent import LLM_Model, ModelSettings
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.settings import Settings
+from reworkd_platform.logging import logger
class WrappedChatOpenAI(ChatOpenAI):
@@ -13,8 +14,9 @@ class WrappedChatOpenAI(ChatOpenAI):
default=None,
description="Meta private value but mypy will complain its missing",
)
- max_tokens: int
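+    # None means "not yet set"; TokenService.calculate_max_tokens assigns a value later.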
+ max_tokens: Optional[int] = None
model_name: LLM_Model = Field(alias="model")
+ ollama_api_key: Optional[str] = None
class WrappedAzureChatOpenAI(AzureChatOpenAI, WrappedChatOpenAI):
@@ -33,6 +35,8 @@ def create_model(
streaming: bool = False,
force_model: Optional[LLM_Model] = None,
) -> WrappedChat:
+ logger.info("Creating model with settings: %s, model_settings: %s, user: %s, streaming: %s, force_model: %s",
+ settings, model_settings, user, streaming, force_model)
use_azure = (
not model_settings.custom_api_key and "azure" in settings.openai_api_base
)
@@ -40,9 +44,9 @@ def create_model(
llm_model = force_model or model_settings.model
model: Type[WrappedChat] = WrappedChatOpenAI
base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)
+
kwargs = {
"openai_api_base": base,
- "openai_api_key": model_settings.custom_api_key or settings.openai_api_key,
"temperature": model_settings.temperature,
"model": llm_model,
"max_tokens": model_settings.max_tokens,
@@ -66,6 +70,10 @@ def create_model(
if use_helicone:
kwargs["model"] = deployment_name
+ if settings.ollama_api_key:
+ kwargs["ollama_api_key"] = settings.ollama_api_key
+
+ logger.info("Model created with kwargs: %s", kwargs)
return model(**kwargs) # type: ignore
diff --git a/platform/reworkd_platform/web/api/agent/tools/code.py b/platform/reworkd_platform/web/api/agent/tools/code.py
index 1b0053f61e..ebda51fb0c 100644
--- a/platform/reworkd_platform/web/api/agent/tools/code.py
+++ b/platform/reworkd_platform/web/api/agent/tools/code.py
@@ -2,7 +2,7 @@
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from lanarky.responses import StreamingResponse
-from langchain import LLMChain
+from ollama import Client # Updated import
from reworkd_platform.web.api.agent.tools.tool import Tool
@@ -16,10 +16,19 @@ async def call(
) -> FastAPIStreamingResponse:
from reworkd_platform.web.api.agent.prompts import code_prompt
- chain = LLMChain(llm=self.model, prompt=code_prompt)
-
- return StreamingResponse.from_chain(
- chain,
- {"goal": goal, "language": self.language, "task": task},
- media_type="text/event-stream",
+        client = Client(host='http://localhost:11434')  # Specify host if different
+        response = client.chat(
+            model="llama3.2",
+            messages=[
+                # code_prompt is a prompt template; render it with the inputs the old chain received
+                {"role": "system", "content": code_prompt.format(goal=goal, language=self.language, task=task)},
+                {"role": "user", "content": input_str}
+            ],
+            stream=True,  # stream chunks back to the caller
+        )
+
+ # Create a generator to yield streaming responses
+ async def stream_response():
+ for chunk in response:
+ yield chunk['message']['content']
+
+ return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream")
diff --git a/platform/reworkd_platform/web/api/agent/tools/image.py b/platform/reworkd_platform/web/api/agent/tools/image.py
index f93444875e..58aeb690fa 100644
--- a/platform/reworkd_platform/web/api/agent/tools/image.py
+++ b/platform/reworkd_platform/web/api/agent/tools/image.py
@@ -1,47 +1,10 @@
from typing import Any
-import openai
-import replicate
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
-from replicate.exceptions import ModelError
-from replicate.exceptions import ReplicateError as ReplicateAPIError
+from ollama import Client # Updated import
-from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.stream_mock import stream_string
from reworkd_platform.web.api.agent.tools.tool import Tool
-from reworkd_platform.web.api.errors import ReplicateError
-
-
-async def get_replicate_image(input_str: str) -> str:
- if settings.replicate_api_key is None or settings.replicate_api_key == "":
- raise RuntimeError("Replicate API key not set")
-
- client = replicate.Client(settings.replicate_api_key)
- try:
- output = client.run(
- "stability-ai/stable-diffusion"
- ":db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
- input={"prompt": input_str},
- image_dimensions="512x512",
- )
- except ModelError as e:
- raise ReplicateError(e, "Image generation failed due to NSFW image.")
- except ReplicateAPIError as e:
- raise ReplicateError(e, "Failed to generate an image.")
-
- return output[0]
-
-
-# Use AI to generate an Image based on a prompt
-async def get_open_ai_image(input_str: str) -> str:
- response = openai.Image.create(
- api_key=settings.openai_api_key,
- prompt=input_str,
- n=1,
- size="256x256",
- )
-
- return response["data"][0]["url"]
class Image(Tool):
@@ -52,15 +15,30 @@ class Image(Tool):
"This should be a detailed description of the image touching on image "
"style, image focus, color, etc."
)
- image_url = "/tools/replicate.png"
+ image_url = "/tools/ollama.png"
async def call(
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
) -> FastAPIStreamingResponse:
- # Use the replicate API if its available, otherwise use DALL-E
- try:
- url = await get_replicate_image(input_str)
- except RuntimeError:
- url = await get_open_ai_image(input_str)
+ client = Client(host='http://localhost:11434') # Specify host if different
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": "Generate an image based on the following description."},
+ {"role": "user", "content": input_str}
+ ],
+ stream=True,
+ )
+
+        # Accumulate the streamed completion; it is assumed to contain a URL
+        # or identifier for the generated image.
+        image_url = ""
+
+ async def stream_response():
+ nonlocal image_url
+ for chunk in response:
+ content = chunk['message']['content']
+ image_url += content # Adjust based on actual response structure
+
+ await stream_response()
- return stream_string(f"![{input_str}]({url})")
+ return stream_string(f"![{input_str}]({image_url})")
diff --git a/platform/reworkd_platform/web/api/agent/tools/search.py b/platform/reworkd_platform/web/api/agent/tools/search.py
index fb83e8213b..31de32d6d9 100644
--- a/platform/reworkd_platform/web/api/agent/tools/search.py
+++ b/platform/reworkd_platform/web/api/agent/tools/search.py
@@ -1,10 +1,10 @@
+import json
from typing import Any, List
from urllib.parse import quote
-import aiohttp
-from aiohttp import ClientResponseError
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from loguru import logger
+from ollama import Client # Updated import
from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.stream_mock import stream_string
@@ -15,28 +15,7 @@
summarize_with_sources,
)
-# Search google via serper.dev. Adapted from LangChain
-# https://github.com/hwchase17/langchain/blob/master/langchain/utilities
-
-
-async def _google_serper_search_results(
- search_term: str, search_type: str = "search"
-) -> dict[str, Any]:
- headers = {
- "X-API-KEY": settings.serp_api_key or "",
- "Content-Type": "application/json",
- }
- params = {
- "q": search_term,
- }
-
- async with aiohttp.ClientSession() as session:
- async with session.post(
- f"https://google.serper.dev/{search_type}", headers=headers, params=params
- ) as response:
- response.raise_for_status()
- search_results = await response.json()
- return search_results
+# Search Google via Ollama model
class Search(Tool):
@@ -44,7 +23,7 @@ class Search(Tool):
"Search Google for short up to date searches for simple questions about public information "
"news and people.\n"
)
- public_description = "Search google for information about current events."
+ public_description = "Search Google for information about current events."
arg_description = "The query argument to search for. This value is always populated and cannot be an empty string."
image_url = "/tools/google.png"
@@ -57,8 +36,8 @@ async def call(
) -> FastAPIStreamingResponse:
try:
return await self._call(goal, task, input_str, *args, **kwargs)
- except ClientResponseError:
- logger.exception("Error calling Serper API, falling back to reasoning")
+ except Exception:
+ logger.exception("Error calling Ollama model, falling back to reasoning")
return await Reason(self.model, self.language).call(
goal, task, input_str, *args, **kwargs
)
@@ -66,42 +45,59 @@ async def call(
async def _call(
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
) -> FastAPIStreamingResponse:
- results = await _google_serper_search_results(
- input_str,
+ client = Client(host='http://localhost:11434') # Specify host if different
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": "Perform a Google search based on the following query."},
+ {"role": "user", "content": input_str}
+ ],
+ stream=True,
)
- k = 5 # Number of results to return
snippets: List[CitedSnippet] = []
- if results.get("answerBox"):
- answer_values = []
- answer_box = results.get("answerBox", {})
- if answer_box.get("answer"):
- answer_values.append(answer_box.get("answer"))
- elif answer_box.get("snippet"):
- answer_values.append(answer_box.get("snippet").replace("\n", " "))
- elif answer_box.get("snippetHighlighted"):
- answer_values.append(", ".join(answer_box.get("snippetHighlighted")))
-
- if len(answer_values) > 0:
- snippets.append(
- CitedSnippet(
- len(snippets) + 1,
- "\n".join(answer_values),
- f"https://www.google.com/search?q={quote(input_str)}",
- )
- )
-
- for i, result in enumerate(results["organic"][:k]):
- texts = []
- link = ""
- if "snippet" in result:
- texts.append(result["snippet"])
- if "link" in result:
- link = result["link"]
- for attribute, value in result.get("attributes", {}).items():
- texts.append(f"{attribute}: {value}.")
- snippets.append(CitedSnippet(len(snippets) + 1, "\n".join(texts), link))
+ async def process_response():
+ # Accumulate the streamed chunks first; individual chunks are token
+ # fragments and cannot be parsed as JSON on their own
+ full_content = ""
+ for chunk in response:
+ full_content += chunk['message']['content']
+
+ try:
+ results = json.loads(full_content)
+ except json.JSONDecodeError:
+ return # Not valid JSON; fall through to the no-results message
+
+ if results.get("answerBox"):
+ answer_values = []
+ answer_box = results.get("answerBox", {})
+ if answer_box.get("answer"):
+ answer_values.append(answer_box.get("answer"))
+ elif answer_box.get("snippet"):
+ answer_values.append(answer_box.get("snippet").replace("\n", " "))
+ elif answer_box.get("snippetHighlighted"):
+ answer_values.append(", ".join(answer_box.get("snippetHighlighted")))
+
+ if len(answer_values) > 0:
+ snippets.append(
+ CitedSnippet(
+ len(snippets) + 1,
+ "\n".join(answer_values),
+ f"https://www.google.com/search?q={quote(input_str)}",
+ )
+ )
+
+ for result in results.get("organic", [])[:5]:
+ texts = []
+ link = ""
+ if "snippet" in result:
+ texts.append(result["snippet"])
+ if "link" in result:
+ link = result["link"]
+ for attribute, value in result.get("attributes", {}).items():
+ texts.append(f"{attribute}: {value}.")
+ snippets.append(CitedSnippet(len(snippets) + 1, "\n".join(texts), link))
+
+ await process_response()
if len(snippets) == 0:
return stream_string("No good Google Search Result was found", True)
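
Because `_call` now depends on the model emitting parseable JSON, it may help to constrain the output explicitly. Recent versions of Ollama accept a `format="json"` argument for this; the sketch below is illustrative only (host, model, and the expected result schema are all assumptions):

```python
import json
from ollama import Client

client = Client(host="http://localhost:11434")

# format="json" asks Ollama to emit syntactically valid JSON, which makes
# the subsequent json.loads far more likely to succeed.
response = client.chat(
    model="llama3.2",
    messages=[
        {
            "role": "system",
            "content": "Return search results as JSON with an 'organic' "
            "list of objects containing 'snippet' and 'link' keys.",
        },
        {"role": "user", "content": "latest FastAPI release"},
    ],
    format="json",
)

results = json.loads(response["message"]["content"])
for result in results.get("organic", [])[:5]:
    print(result.get("snippet", ""), "->", result.get("link", ""))
```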
diff --git a/platform/reworkd_platform/web/api/agent/tools/sidsearch.py b/platform/reworkd_platform/web/api/agent/tools/sidsearch.py
index e99453cd77..9efb4c61c8 100644
--- a/platform/reworkd_platform/web/api/agent/tools/sidsearch.py
+++ b/platform/reworkd_platform/web/api/agent/tools/sidsearch.py
@@ -2,9 +2,9 @@
from datetime import datetime, timedelta
from typing import Any, List, Optional
-import aiohttp
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from loguru import logger
+from ollama import Client # Updated import
from reworkd_platform.db.crud.oauth import OAuthCrud
from reworkd_platform.db.models.auth import OauthCredentials
@@ -24,15 +24,29 @@ async def _sid_search_results(
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
data = {"query": search_term, "limit": limit}
- async with aiohttp.ClientSession() as session:
- async with session.post(
- "https://api.sid.ai/v1/users/me/query",
- headers=headers,
- data=json.dumps(data),
- ) as response:
- response.raise_for_status()
- search_results = await response.json()
- return search_results
+ client = Client(host='http://localhost:11434') # Specify host if different
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": "Search through personal data sources."},
+ {"role": "user", "content": search_term}
+ ],
+ stream=True,
+ )
+
+ search_results = {}
+
+ async def process_response():
+ nonlocal search_results
+ # Accumulate the full streamed message before parsing; individual
+ # chunks are token fragments, not complete JSON documents
+ full_content = ""
+ for chunk in response:
+ full_content += chunk['message']['content']
+ try:
+ search_results = json.loads(full_content)
+ except json.JSONDecodeError:
+ pass # Handle or log the error as needed
+
+ await process_response()
+ return search_results
async def token_exchange(refresh_token: str) -> tuple[str, datetime]:
@@ -43,14 +57,31 @@ async def token_exchange(refresh_token: str) -> tuple[str, datetime]:
"redirect_uri": settings.sid_redirect_uri,
"refresh_token": refresh_token,
}
- async with aiohttp.ClientSession() as session:
- async with session.post(
- "https://auth.sid.ai/oauth/token", data=data
- ) as response:
- response.raise_for_status()
- response_data = await response.json()
- access_token = response_data["access_token"]
- expires_in = response_data["expires_in"]
+ client = Client(host='http://localhost:11434') # Specify host if different
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": "Exchange refresh token for access token."},
+ {"role": "user", "content": json.dumps(data)}
+ ],
+ stream=True,
+ )
+
+ response_data = {}
+
+ async def process_response():
+ nonlocal response_data
+ # Accumulate the full streamed message before parsing as JSON
+ full_content = ""
+ for chunk in response:
+ full_content += chunk['message']['content']
+ try:
+ response_data = json.loads(full_content)
+ except json.JSONDecodeError:
+ pass # Handle or log the error as needed
+
+ await process_response()
+
+ access_token = response_data["access_token"]
+ expires_in = response_data["expires_in"]
return access_token, datetime.now() + timedelta(seconds=expires_in)
@@ -126,7 +157,6 @@ async def _run_sid(
return summarize_sid(self.model, self.language, goal, task, snippets)
-
async def call(
self,
goal: str,
@@ -137,7 +167,7 @@ async def call(
*args: Any,
**kwargs: Any,
) -> FastAPIStreamingResponse:
- # fall back to search if no results are found
+ # fall back to search if no results are found
return await self._run_sid(goal, task, input_str, user, oauth_crud) or await Search(self.model, self.language).call(
- goal, task, input_str, user, oauth_crud
- )
+ goal, task, input_str, user, oauth_crud
+ )
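
The accumulate-then-parse routine now appears in three places (`search.py`, and twice in `sidsearch.py`). It could be factored into a shared helper; the following is a hypothetical utility, not part of the diff:

```python
import json
from typing import Any, Dict, Iterable

def collect_json(chunks: Iterable[dict]) -> Dict[str, Any]:
    """Join a streamed Ollama chat response and parse it as JSON.

    Returns an empty dict when the model output is not valid JSON.
    """
    full_content = "".join(chunk["message"]["content"] for chunk in chunks)
    try:
        return json.loads(full_content)
    except json.JSONDecodeError:
        return {}
```

Each call site would then reduce to a single line such as `search_results = collect_json(response)`.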
diff --git a/platform/reworkd_platform/web/api/agent/tools/utils.py b/platform/reworkd_platform/web/api/agent/tools/utils.py
index 5a1b7755a4..850f4278de 100644
--- a/platform/reworkd_platform/web/api/agent/tools/utils.py
+++ b/platform/reworkd_platform/web/api/agent/tools/utils.py
@@ -1,11 +1,9 @@
from dataclasses import dataclass
-from typing import List
+from typing import List, AsyncGenerator
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
-from lanarky.responses import StreamingResponse
-from langchain import LLMChain
-from langchain.chat_models.base import BaseChatModel
-
+from ollama import Client # Updated import
@dataclass
class CitedSnippet:
@@ -31,29 +29,32 @@ def __repr__(self) -> str:
return f"{{text: {self.text}}}"
-def summarize(
- model: BaseChatModel,
+async def summarize(
+ client: Client,
language: str,
goal: str,
text: str,
) -> FastAPIStreamingResponse:
from reworkd_platform.web.api.agent.prompts import summarize_prompt
- chain = LLMChain(llm=model, prompt=summarize_prompt)
-
- return StreamingResponse.from_chain(
- chain,
- {
- "goal": goal,
- "language": language,
- "text": text,
- },
- media_type="text/event-stream",
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": summarize_prompt},
+ {"role": "user", "content": text}
+ ],
+ stream=True,
)
+ async def stream_response():
+ for chunk in response:
+ yield chunk['message']['content']
+
+ return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream")
-def summarize_with_sources(
- model: BaseChatModel,
+
+async def summarize_with_sources(
+ client: Client,
language: str,
goal: str,
query: str,
@@ -61,22 +62,26 @@ def summarize_with_sources(
) -> FastAPIStreamingResponse:
from reworkd_platform.web.api.agent.prompts import summarize_with_sources_prompt
- chain = LLMChain(llm=model, prompt=summarize_with_sources_prompt)
-
- return StreamingResponse.from_chain(
- chain,
- {
- "goal": goal,
- "query": query,
- "language": language,
- "snippets": snippets,
- },
- media_type="text/event-stream",
+ combined_snippets = "\n".join([snippet.text for snippet in snippets])
+
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": summarize_with_sources_prompt},
+ {"role": "user", "content": combined_snippets}
+ ],
+ stream=True,
)
+ async def stream_response():
+ for chunk in response:
+ yield chunk['message']['content']
-def summarize_sid(
- model: BaseChatModel,
+ return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream")
+
+
+async def summarize_sid(
+ client: Client,
language: str,
goal: str,
query: str,
@@ -84,15 +89,19 @@ def summarize_sid(
) -> FastAPIStreamingResponse:
from reworkd_platform.web.api.agent.prompts import summarize_sid_prompt
- chain = LLMChain(llm=model, prompt=summarize_sid_prompt)
-
- return StreamingResponse.from_chain(
- chain,
- {
- "goal": goal,
- "query": query,
- "language": language,
- "snippets": snippets,
- },
- media_type="text/event-stream",
+ combined_snippets = "\n".join([snippet.text for snippet in snippets])
+
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": summarize_sid_prompt},
+ {"role": "user", "content": combined_snippets}
+ ],
+ stream=True,
)
+
+ async def stream_response():
+ for chunk in response:
+ yield chunk['message']['content']
+
+ return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream")
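
Note that `summarize`, `summarize_with_sources`, and `summarize_sid` are now coroutines taking a `Client` instead of a `BaseChatModel`, so every call site must be updated to pass a client and `await` the result. A hypothetical call site follows, assuming `summarize_prompt` renders to a plain string (in the original code it is a LangChain prompt template, which would need formatting before use as message content):

```python
import asyncio
from ollama import Client
from reworkd_platform.web.api.agent.tools.utils import summarize

async def main() -> None:
    client = Client(host="http://localhost:11434")
    response = await summarize(client, "en", "Research goal", "Long text to condense...")
    # FastAPIStreamingResponse exposes the wrapped async generator
    async for chunk in response.body_iterator:
        print(chunk, end="")

asyncio.run(main())
```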
diff --git a/platform/reworkd_platform/web/api/dependencies.py b/platform/reworkd_platform/web/api/dependencies.py
index 95463373e4..2932d5cfa7 100644
--- a/platform/reworkd_platform/web/api/dependencies.py
+++ b/platform/reworkd_platform/web/api/dependencies.py
@@ -10,6 +10,7 @@
from reworkd_platform.db.dependencies import get_db_session
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.web.api.http_responses import forbidden
+from loguru import logger
def user_crud(
@@ -28,11 +29,14 @@ async def get_current_user(
try:
session = await crud.get_user_session(session_token)
except NoResultFound:
+ logger.error("Invalid session token")
raise forbidden("Invalid session token")
if session.expires <= datetime.utcnow():
+ logger.error("Session token expired")
raise forbidden("Session token expired")
+ logger.info(f"User {session.user.id} authenticated successfully")
return UserBase(
id=session.user.id,
name=session.user.name,
diff --git a/platform/setup.py b/platform/setup.py
new file mode 100644
index 0000000000..23c383fd10
--- /dev/null
+++ b/platform/setup.py
@@ -0,0 +1,65 @@
+from setuptools import setup, find_packages
+
+setup(
+ name="reworkd_platform",
+ version="0.1.0",
+ description="A platform for reworkd",
+ author="awtkns, asim-shrestha",
+ author_email="",
+ url="https://github.com/reworkd/AgentGPT",
+ packages=find_packages(),
+ install_requires=[
+ "fastapi==0.98.0",
+ "boto3==1.28.51",
+ "uvicorn[standard]==0.22.0",
+ "pydantic[dotenv]<2.0",
+ "ujson==5.8.0",
+ "sqlalchemy[mypy,asyncio]==2.0.21",
+ "aiomysql==0.1.1",
+ "mysqlclient==2.2.0",
+ "sentry-sdk==1.31.0",
+ "loguru==0.7.2",
+ "aiokafka==0.8.1",
+ "requests==2.31.0",
+ "langchain==0.0.295",
+ "openai==0.28.0",
+ "wikipedia==1.4.0",
+ "replicate==0.8.4",
+ "lanarky==0.7.15",
+ "tiktoken==0.5.1",
+ "grpcio==1.58.0",
+ "pinecone-client==2.2.4",
+ "python-multipart==0.0.6",
+ "aws-secretsmanager-caching==1.1.1.5",
+ "botocore==1.31.51",
+ "stripe==5.5.0",
+ "cryptography==41.0.4",
+ "httpx==0.25.0",
+ ],
+ extras_require={
+ "dev": [
+ "autopep8==2.0.4",
+ "pytest==7.4.2",
+ "flake8==6.0.0",
+ "mypy==1.5.1",
+ "isort==5.12.0",
+ "pre-commit==3.4.0",
+ "wemake-python-styleguide==0.18.0",
+ "black==23.9.1",
+ "autoflake==2.2.1",
+ "pytest-cov==4.1.0",
+ "anyio==3.7.1",
+ "pytest-env==0.8.2",
+ "dotmap==1.3.30",
+ "pytest-mock==3.10.0",
+ "pytest-asyncio==0.21.0",
+ "types-requests==2.31.0.1",
+ "types-pytz==2023.3.0.0",
+ ],
+ },
+ entry_points={
+ "console_scripts": [
+ "reworkd_platform=reworkd_platform.__main__:main",
+ ],
+ },
+)
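
With this file in place, the platform can be installed in editable mode via `pip install -e ".[dev]"` from the `platform/` directory. A small, hypothetical sanity check that the declared console-script target resolves after installation:

```python
# Verify that the entry point declared in setup.py actually resolves.
from importlib import import_module

module = import_module("reworkd_platform.__main__")
assert callable(getattr(module, "main", None)), "entry point main() not found"
print("reworkd_platform entry point OK")
```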