From 5f1a5ba3811d12baca9e9cf741762f042ca86a7d Mon Sep 17 00:00:00 2001 From: Rodrigo Maldonado <32023142+rodrigo-92@users.noreply.github.com> Date: Mon, 18 Nov 2024 16:45:28 -0500 Subject: [PATCH] Improvement/sambanova integration sambastudio (#16930) --- ...{sambanovacloud.md => sambanovasystems.md} | 0 docs/docs/examples/llm/sambanova.ipynb | 325 ---- docs/docs/examples/llm/sambanovasystems.ipynb | 643 +++++++ docs/mkdocs.yml | 4 +- .../llms/sambanovacloud/__init__.py | 3 - .../llama_index/llms/sambanovacloud/base.py | 641 ------- .../.gitignore | 0 .../BUILD | 0 .../Makefile | 0 .../README.md | 14 +- .../llama_index/llms/sambanovasystems}/BUILD | 0 .../llms/sambanovasystems/__init__.py | 3 + .../llama_index/llms/sambanovasystems/base.py | 1570 +++++++++++++++++ .../pyproject.toml | 11 +- .../tests/BUILD | 0 .../tests/__init__.py | 0 .../tests/test_llms_sambanovasystems.py} | 120 +- 17 files changed, 2295 insertions(+), 1039 deletions(-) rename docs/docs/api_reference/llms/{sambanovacloud.md => sambanovasystems.md} (100%) delete mode 100644 docs/docs/examples/llm/sambanova.ipynb create mode 100644 docs/docs/examples/llm/sambanovasystems.ipynb delete mode 100644 llama-index-integrations/llms/llama-index-llms-sambanovacloud/llama_index/llms/sambanovacloud/__init__.py delete mode 100644 llama-index-integrations/llms/llama-index-llms-sambanovacloud/llama_index/llms/sambanovacloud/base.py rename llama-index-integrations/llms/{llama-index-llms-sambanovacloud => llama-index-llms-sambanovasystems}/.gitignore (100%) rename llama-index-integrations/llms/{llama-index-llms-sambanovacloud => llama-index-llms-sambanovasystems}/BUILD (100%) rename llama-index-integrations/llms/{llama-index-llms-sambanovacloud => llama-index-llms-sambanovasystems}/Makefile (100%) rename llama-index-integrations/llms/{llama-index-llms-sambanovacloud => llama-index-llms-sambanovasystems}/README.md (55%) rename llama-index-integrations/llms/{llama-index-llms-sambanovacloud/llama_index/llms/sambanovacloud => llama-index-llms-sambanovasystems/llama_index/llms/sambanovasystems}/BUILD (100%) create mode 100644 llama-index-integrations/llms/llama-index-llms-sambanovasystems/llama_index/llms/sambanovasystems/__init__.py create mode 100644 llama-index-integrations/llms/llama-index-llms-sambanovasystems/llama_index/llms/sambanovasystems/base.py rename llama-index-integrations/llms/{llama-index-llms-sambanovacloud => llama-index-llms-sambanovasystems}/pyproject.toml (79%) rename llama-index-integrations/llms/{llama-index-llms-sambanovacloud => llama-index-llms-sambanovasystems}/tests/BUILD (100%) rename llama-index-integrations/llms/{llama-index-llms-sambanovacloud => llama-index-llms-sambanovasystems}/tests/__init__.py (100%) rename llama-index-integrations/llms/{llama-index-llms-sambanovacloud/tests/test_llms_sambanova.py => llama-index-llms-sambanovasystems/tests/test_llms_sambanovasystems.py} (60%) diff --git a/docs/docs/api_reference/llms/sambanovacloud.md b/docs/docs/api_reference/llms/sambanovasystems.md similarity index 100% rename from docs/docs/api_reference/llms/sambanovacloud.md rename to docs/docs/api_reference/llms/sambanovasystems.md diff --git a/docs/docs/examples/llm/sambanova.ipynb b/docs/docs/examples/llm/sambanova.ipynb deleted file mode 100644 index b093259844708..0000000000000 --- a/docs/docs/examples/llm/sambanova.ipynb +++ /dev/null @@ -1,325 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# SambaNova Cloud\n", - "\n", - "This will help you getting started with **[SambaNova](https://sambanova.ai/)'s** [SambaNova Cloud](https://cloud.sambanova.ai/), which is a platform for performing inference with open-source models.\n", - "\n", - "## Setup\n", - "\n", - "To access SambaNova Cloud model you will need to create a [SambaNovaCloud](https://cloud.sambanova.ai/) account, get an API key, install the `llama-index-llms-sambanova` integration package, and install the `SSEClient` Package.\n", - "\n", - "```bash\n", - "pip install llama-index-llms-sambanovacloud\n", - "pip install sseclient-py\n", - "```\n", - "\n", - "### Credentials\n", - "\n", - "Get an API Key from [cloud.sambanova.ai](https://cloud.sambanova.ai/apis) and add it to your environment variables:\n", - "\n", - "``` bash\n", - "export SAMBANOVA_API_KEY=\"your-api-key-here\"\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import getpass\n", - "import os\n", - "\n", - "if not os.getenv(\"SAMBANOVA_API_KEY\"):\n", - " os.environ[\"SAMBANOVA_API_KEY\"] = getpass.getpass(\n", - " \"Enter your SambaNova Cloud API key: \"\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Installation\n", - "\n", - "The Llama-Index __SambaNovaCloud__ integration lives in the `langchain-index-integrations` package, and it can be installed with the following commands:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%pip install \"llama-index-llms-sambanovacloud\"\n", - "%pip install sseclient-py" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Instantiation\n", - "\n", - "Now we can instantiate our model object and generate chat completions:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from llama_index.llms.sambanovacloud import SambaNovaCloud\n", - "\n", - "llm = SambaNovaCloud(\n", - " model=\"Meta-Llama-3.1-70B-Instruct\",\n", - " max_tokens=1024,\n", - " temperature=0.7,\n", - " top_k=1,\n", - " top_p=0.01,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Invocation\n", - "\n", - "Given the following system and user messages, let's explore different ways of calling a SambaNova Cloud model. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from llama_index.core.base.llms.types import (\n", - " ChatMessage,\n", - " MessageRole,\n", - ")\n", - "\n", - "system_msg = ChatMessage(\n", - " role=MessageRole.SYSTEM,\n", - " content=\"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n", - ")\n", - "user_msg = ChatMessage(role=MessageRole.USER, content=\"I love programming.\")\n", - "\n", - "messages = [\n", - " system_msg,\n", - " user_msg,\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Chat" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ai_msg = llm.chat(messages)\n", - "ai_msg.message" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(ai_msg.message.content)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Complete" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ai_msg = llm.complete(user_msg.content)\n", - "ai_msg" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(ai_msg.text)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Streaming" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Chat" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ai_stream_msgs = []\n", - "for stream in llm.stream_chat(messages):\n", - " ai_stream_msgs.append(stream)\n", - "ai_stream_msgs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(ai_stream_msgs[-1])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Complete" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ai_stream_msgs = []\n", - "for stream in llm.stream_complete(user_msg.content):\n", - " ai_stream_msgs.append(stream)\n", - "ai_stream_msgs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(ai_stream_msgs[-1])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Async" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Chat" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ai_msg = await llm.achat(messages)\n", - "ai_msg" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(ai_msg.message.content)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Complete" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ai_msg = await llm.acomplete(user_msg.content)\n", - "ai_msg" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(ai_msg.text)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Async Streaming\n", - "\n", - "Not supported yet. Coming soon!" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "llamaindex_venv", - "language": "python", - "name": "llamaindex_venv" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/docs/docs/examples/llm/sambanovasystems.ipynb b/docs/docs/examples/llm/sambanovasystems.ipynb new file mode 100644 index 0000000000000..cda7e05c0cb0d --- /dev/null +++ b/docs/docs/examples/llm/sambanovasystems.ipynb @@ -0,0 +1,643 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SambaNova Systems\n", + "\n", + "In this notebook you will know how to install, setup and use the [SambaNova Cloud](https://cloud.sambanova.ai/) and [SambaStudio](https://docs.sambanova.ai/sambastudio/latest/sambastudio-intro.html) platforms. Take a look and try it yourself!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SambaNova Cloud\n", + "\n", + "[SambaNova Cloud](https://cloud.sambanova.ai/) is a high-performance inference service that delivers rapid and precise results. Customers can seamlessly leverage SambaNova technology to enhance their user experience by integrating FastAPI inference APIs with their applications. This service provides an easy-to-use REST interface for streaming the inference results. Users are able to customize the inference parameters and pass the ML model on to the service.\n", + "\n", + "## Setup\n", + "\n", + "To access SambaNova Cloud model you will need to create a [SambaNovaCloud](https://cloud.sambanova.ai/apis) account, get an API key, install the `llama-index-llms-sambanova` integration package, and install the `SSEClient` Package.\n", + "\n", + "```bash\n", + "pip install llama-index-llms-sambanovacloud\n", + "pip install sseclient-py\n", + "```\n", + "\n", + "### Credentials\n", + "\n", + "Get an API Key from [cloud.sambanova.ai](https://cloud.sambanova.ai/apis) and add it to your environment variables:\n", + "\n", + "``` bash\n", + "export SAMBANOVA_API_KEY=\"your-api-key-here\"\n", + "```\n", + "\n", + "If you don't have it in your env variables, you can also add it in the pop-up input text." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if not os.getenv(\"SAMBANOVA_API_KEY\"):\n", + " os.environ[\"SAMBANOVA_API_KEY\"] = getpass.getpass(\n", + " \"Enter your SambaNova Cloud API key: \"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "The Llama-Index __SambaNova Cloud__ integration lives in the `llama-index-integrations` package, and it can be installed with the following commands:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install \"llama-index-llms-sambanova\"\n", + "%pip install sseclient-py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiation\n", + "\n", + "Now we can instantiate our model object and generate chat completions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.llms.sambanovasystems import SambaNovaCloud\n", + "\n", + "llm = SambaNovaCloud(\n", + " model=\"Meta-Llama-3.1-70B-Instruct\",\n", + " max_tokens=1024,\n", + " temperature=0.7,\n", + " top_k=1,\n", + " top_p=0.01,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Invocation\n", + "\n", + "Given the following system and user messages, let's explore different ways of calling a SambaNova Cloud model. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.base.llms.types import (\n", + " ChatMessage,\n", + " MessageRole,\n", + ")\n", + "\n", + "system_msg = ChatMessage(\n", + " role=MessageRole.SYSTEM,\n", + " content=\"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n", + ")\n", + "user_msg = ChatMessage(role=MessageRole.USER, content=\"I love programming.\")\n", + "\n", + "messages = [\n", + " system_msg,\n", + " user_msg,\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Chat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ai_msg = llm.chat(messages)\n", + "ai_msg.message" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ai_msg.message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Complete" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ai_msg = llm.complete(user_msg.content)\n", + "ai_msg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ai_msg.text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Streaming" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Chat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ai_stream_msgs = []\n", + "for stream in llm.stream_chat(messages):\n", + " ai_stream_msgs.append(stream)\n", + "ai_stream_msgs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ai_stream_msgs[-1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Complete" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ai_stream_msgs = []\n", + "for stream in llm.stream_complete(user_msg.content):\n", + " ai_stream_msgs.append(stream)\n", + "ai_stream_msgs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ai_stream_msgs[-1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Async" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Chat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ai_msg = await llm.achat(messages)\n", + "ai_msg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ai_msg.message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Complete" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ai_msg = await llm.acomplete(user_msg.content)\n", + "ai_msg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ai_msg.text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Async Streaming\n", + "\n", + "Not supported yet. Coming soon!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SambaStudio\n", + "\n", + "[SambaStudio](https://docs.sambanova.ai/sambastudio/latest/sambastudio-intro.html) is a rich, GUI-based platform that provides the functionality to train, deploy, and manage models.\n", + "\n", + "## Setup\n", + "\n", + "To access SambaStudio models you will need to be a __SambaNova customer__, deploy an endpoint using the GUI or CLI, and use the URL and API Key to connect to the endpoint, as described in the [SambaStudio endpoint documentation](https://docs.sambanova.ai/sambastudio/latest/endpoints.html#_endpoint_api_keys). Then, install the `llama-index-llms-sambanova` integration package, and install the `SSEClient` Package.\n", + "\n", + "```bash\n", + "pip install llama-index-llms-sambanova\n", + "pip install sseclient-py\n", + "```\n", + "\n", + "### Credentials\n", + "\n", + "An endpoint must be deployed in SambaStudio to get the URL and API Key. Once they're available, include them to your environment variables:\n", + "\n", + "``` bash\n", + "export SAMBASTUDIO_URL=\"your-url-here\"\n", + "export SAMBASTUDIO_API_KEY=\"your-api-key-here\"\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if not os.getenv(\"SAMBASTUDIO_URL\"):\n", + " os.environ[\"SAMBASTUDIO_URL\"] = getpass.getpass(\n", + " \"Enter your SambaStudio endpoint's URL: \"\n", + " )\n", + "\n", + "if not os.getenv(\"SAMBASTUDIO_API_KEY\"):\n", + " os.environ[\"SAMBASTUDIO_API_KEY\"] = getpass.getpass(\n", + " \"Enter your SambaStudio endpoint's API key: \"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "The Llama-Index __SambaStudio__ integration lives in the `llama-index-integrations` package, and it can be installed with the following commands:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install \"llama-index-llms-sambanova\"\n", + "%pip install sseclient-py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiation\n", + "\n", + "Now we can instantiate our model object and generate chat completions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.llms.sambanovasystems import SambaStudio\n", + "\n", + "llm = SambaStudio(\n", + " model=\"Meta-Llama-3-70B-Instruct-4096\",\n", + " max_tokens=1024,\n", + " temperature=0.7,\n", + " top_k=1,\n", + " top_p=0.01,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Invocation\n", + "\n", + "Given the following system and user messages, let's explore different ways of calling a SambaNova Cloud model. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.base.llms.types import (\n", + " ChatMessage,\n", + " MessageRole,\n", + ")\n", + "\n", + "system_msg = ChatMessage(\n", + " role=MessageRole.SYSTEM,\n", + " content=\"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n", + ")\n", + "user_msg = ChatMessage(role=MessageRole.USER, content=\"I love programming.\")\n", + "\n", + "messages = [\n", + " system_msg,\n", + " user_msg,\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Chat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ai_msg = llm.chat(messages)\n", + "ai_msg.message" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ai_msg.message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Complete" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ai_msg = llm.complete(user_msg.content)\n", + "ai_msg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ai_msg.text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Streaming" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Chat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ai_stream_msgs = []\n", + "for stream in llm.stream_chat(messages):\n", + " ai_stream_msgs.append(stream)\n", + "ai_stream_msgs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ai_stream_msgs[-1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Complete" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ai_stream_msgs = []\n", + "for stream in llm.stream_complete(user_msg.content):\n", + " ai_stream_msgs.append(stream)\n", + "ai_stream_msgs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ai_stream_msgs[-1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Async" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Chat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ai_msg = await llm.achat(messages)\n", + "ai_msg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ai_msg.message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Complete" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ai_msg = await llm.acomplete(user_msg.content)\n", + "ai_msg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ai_msg.text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Async Streaming\n", + "\n", + "Not supported yet. Coming soon!" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "llamaindex_venv", + "language": "python", + "name": "llamaindex_venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 77b5712642fd1..f418248786b16 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -1049,7 +1049,7 @@ nav: - ./api_reference/llms/replicate.md - ./api_reference/llms/rungpt.md - ./api_reference/llms/sagemaker_endpoint.md - - ./api_reference/llms/sambanovacloud.md + - ./api_reference/llms/sambanovasystems.md - ./api_reference/llms/siliconflow.md - ./api_reference/llms/solar.md - ./api_reference/llms/text_generation_inference.md @@ -2319,7 +2319,7 @@ plugins: - ../llama-index-integrations/embeddings/llama-index-embeddings-siliconflow - ../llama-index-integrations/memory/llama-index-memory-mem0 - ../llama-index-integrations/postprocessor/llama-index-postprocessor-siliconflow-rerank - - ../llama-index-integrations/llms/llama-index-llms-sambanovacloud + - ../llama-index-integrations/llms/llama-index-llms-sambanovasystems - ../llama-index-integrations/embeddings/llama-index-embeddings-modelscope - ../llama-index-integrations/postprocessor/llama-index-postprocessor-colpali-rerank - ../llama-index-integrations/readers/llama-index-readers-gitbook diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/llama_index/llms/sambanovacloud/__init__.py b/llama-index-integrations/llms/llama-index-llms-sambanovacloud/llama_index/llms/sambanovacloud/__init__.py deleted file mode 100644 index f289193d7e649..0000000000000 --- a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/llama_index/llms/sambanovacloud/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from llama_index.llms.sambanovacloud.base import SambaNovaCloud - -__all__ = ["SambaNovaCloud"] diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/llama_index/llms/sambanovacloud/base.py b/llama-index-integrations/llms/llama-index-llms-sambanovacloud/llama_index/llms/sambanovacloud/base.py deleted file mode 100644 index 27a89277434d2..0000000000000 --- a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/llama_index/llms/sambanovacloud/base.py +++ /dev/null @@ -1,641 +0,0 @@ -import aiohttp - -from typing import Any, Dict, List, Optional, Iterator, Sequence, AsyncIterator - -import requests -from llama_index.core.llms.llm import LLM -from llama_index.core.llms.callbacks import ( - llm_chat_callback, - llm_completion_callback, -) -from llama_index.core.base.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - MessageRole, -) -from llama_index.core.base.llms.generic_utils import ( - get_from_param_or_env, - chat_to_completion_decorator, - stream_chat_to_completion_decorator, - achat_to_completion_decorator, -) -from llama_index.core.bridge.pydantic import Field, SecretStr -import json - -from dotenv import load_dotenv - -load_dotenv() - - -def _convert_message_to_dict(message: ChatMessage) -> Dict[str, Any]: - """Converts a ChatMessage to a dictionary with Role / content. - - Args: - message: ChatMessage - - Returns: - messages_dict: role / content dict - """ - if isinstance(message, ChatMessage): - message_dict = {"role": message.role, "content": message.content} - else: - raise TypeError(f"Got unknown type {message}") - return message_dict - - -def _create_message_dicts(messages: Sequence[ChatMessage]) -> List[Dict[str, Any]]: - """Converts a list of ChatMessages to a list of dictionaries with Role / content. - - Args: - messages: list of ChatMessages - - Returns: - messages_dicts: list of role / content dicts - """ - return [_convert_message_to_dict(m) for m in messages] - - -class SambaNovaCloud(LLM): - """ - SambaNova Cloud model. - - Setup: - To use, you should have the environment variables: - ``SAMBANOVA_URL`` set with your SambaNova Cloud URL. - ``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API Key. - http://cloud.sambanova.ai/ - - Example: - .. code-block:: python - SambaNovaCloud( - sambanova_url = SambaNova cloud endpoint URL, - sambanova_api_key = set with your SambaNova cloud API key, - model = model name, - max_tokens = max number of tokens to generate, - temperature = model temperature, - top_p = model top p, - top_k = model top k, - stream_options = include usage to get generation metrics - ) - - Key init args — completion params: - model: str - The name of the model to use, e.g., Meta-Llama-3-70B-Instruct. - streaming: bool - Whether to use streaming handler when using non streaming methods - max_tokens: int - max tokens to generate - temperature: float - model temperature - top_p: float - model top p - top_k: int - model top k - stream_options: dict - stream options, include usage to get generation metrics - - Key init args — client params: - sambanova_url: str - SambaNova Cloud Url - sambanova_api_key: str - SambaNova Cloud api key - - Instantiate: - .. code-block:: python - - from llama_index.llms.sambanovacloud import SambaNovaCloud - - llm = SambaNovaCloud( - sambanova_url = SambaNova cloud endpoint URL, - sambanova_api_key = set with your SambaNova cloud API key, - model = model name, - max_tokens = max number of tokens to generate, - temperature = model temperature, - top_p = model top p, - top_k = model top k, - stream_options = include usage to get generation metrics - ) - Complete: - .. code-block:: python - prompt = "Tell me about Naruto Uzumaki in one sentence" - response = llm.complete(prompt) - - Chat: - .. code-block:: python - messages = [ - ChatMessage(role=MessageRole.SYSTEM, content=("You're a helpful assistant")), - ChatMessage(role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence") - ] - response = llm.chat(messages) - - Stream: - .. code-block:: python - prompt = "Tell me about Naruto Uzumaki in one sentence" - messages = [ - ChatMessage(role=MessageRole.SYSTEM, content=("You're a helpful assistant")), - ChatMessage(role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence") - ] - for chunk in llm.stream_complete(prompt): - print(chunk.text) - for chunk in llm.stream_chat(messages): - print(chunk.message.content) - - Async: - .. code-block:: python - prompt = "Tell me about Naruto Uzumaki in one sentence" - asyncio.run(llm.acomplete(prompt)) - - messages = [ - ChatMessage(role=MessageRole.SYSTEM, content=("You're a helpful assistant")), - ChatMessage(role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence") - ] - asyncio.run(llm.achat(chat_text_msgs)) - - Response metadata and usage - .. code-block:: python - - messages = [ - ChatMessage(role=MessageRole.SYSTEM, content=("You're a helpful assistant")), - ChatMessage(role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence") - ] - metadata_and_usage = llm.chat(messages).message.additional_kwargs - print(metadata_and_usage) - """ - - sambanova_url: str = Field(default_factory=str, description="SambaNova Cloud Url") - - sambanova_api_key: SecretStr = Field( - default_factory=str, description="SambaNova Cloud api key" - ) - - model: str = Field( - default="Meta-Llama-3.1-8B-Instruct", - description="The name of the model", - ) - - streaming: bool = Field( - default=False, - description="Whether to use streaming handler when using non streaming methods", - ) - - max_tokens: int = Field(default=1024, description="max tokens to generate") - - temperature: float = Field(default=0.7, description="model temperature") - - top_p: Optional[float] = Field(default=None, description="model top p") - - top_k: Optional[int] = Field(default=None, description="model top k") - - stream_options: dict = Field( - default={"include_usage": True}, - description="stream options, include usage to get generation metrics", - ) - - @classmethod - def class_name(cls) -> str: - return "SambaNovaCloud" - - @property - def metadata(self) -> LLMMetadata: - return LLMMetadata( - context_window=None, - num_output=self.max_tokens, - is_chat_model=True, - model_name=self.model, - ) - - def __init__(self, **kwargs: Any) -> None: - """Init and validate environment variables.""" - kwargs["sambanova_url"] = get_from_param_or_env( - "url", - kwargs.get("sambanova_url"), - "SAMBANOVA_URL", - default="https://api.sambanova.ai/v1/chat/completions", - ) - kwargs["sambanova_api_key"] = get_from_param_or_env( - "api_key", kwargs.get("sambanova_api_key"), "SAMBANOVA_API_KEY" - ) - super().__init__(**kwargs) - - def _handle_request( - self, messages_dicts: List[Dict], stop: Optional[List[str]] = None - ) -> Dict[str, Any]: - """ - Performs a post request to the LLM API. - - Args: - messages_dicts: List of role / content dicts to use as input. - stop: list of stop tokens - - Returns: - A response dict. - """ - data = { - "messages": messages_dicts, - "max_tokens": self.max_tokens, - "stop": stop, - "model": self.model, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - } - http_session = requests.Session() - response = http_session.post( - self.sambanova_url, - headers={ - "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}", - "Content-Type": "application/json", - }, - json=data, - ) - if response.status_code != 200: - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}.", - f"{response.text}.", - ) - response_dict = response.json() - if response_dict.get("error"): - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}.", - f"{response_dict}.", - ) - return response_dict - - async def _handle_request_async( - self, messages_dicts: List[Dict], stop: Optional[List[str]] = None - ) -> Dict[str, Any]: - """ - Performs a async post request to the LLM API. - - Args: - messages_dicts: List of role / content dicts to use as input. - stop: list of stop tokens - - Returns: - A response dict. - """ - data = { - "messages": messages_dicts, - "max_tokens": self.max_tokens, - "stop": stop, - "model": self.model, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - } - - async with aiohttp.ClientSession() as session: - async with session.post( - self.sambanova_url, - headers={ - "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}", - "Content-Type": "application/json", - }, - json=data, - ) as response: - if response.status != 200: - raise RuntimeError( - f"Sambanova /complete call failed with status code {response.status}.", - f"{await response.text()}.", - ) - response_dict = await response.json() - if response_dict.get("error"): - raise RuntimeError( - f"Sambanova /complete call failed with status code {response.status}.", - f"{response_dict}.", - ) - return response_dict - - def _handle_streaming_request( - self, messages_dicts: List[Dict], stop: Optional[List[str]] = None - ) -> Iterator[Dict]: - """ - Performs an streaming post request to the LLM API. - - Args: - messages_dicts: List of role / content dicts to use as input. - stop: list of stop tokens - - Yields: - An iterator of response dicts. - """ - try: - import sseclient - except ImportError: - raise ImportError( - "could not import sseclient library" - "Please install it with `pip install sseclient-py`." - ) - data = { - "messages": messages_dicts, - "max_tokens": self.max_tokens, - "stop": stop, - "model": self.model, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "stream": True, - "stream_options": self.stream_options, - } - http_session = requests.Session() - response = http_session.post( - self.sambanova_url, - headers={ - "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}", - "Content-Type": "application/json", - }, - json=data, - stream=True, - ) - - client = sseclient.SSEClient(response) - - if response.status_code != 200: - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{response.text}." - ) - - for event in client.events(): - if event.event == "error_event": - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{event.data}." - ) - - try: - # check if the response is a final event - # in that case event data response is '[DONE]' - if event.data != "[DONE]": - if isinstance(event.data, str): - data = json.loads(event.data) - else: - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{event.data}." - ) - if data.get("error"): - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{event.data}." - ) - yield data - except Exception as e: - raise RuntimeError( - f"Error getting content chunk raw streamed response: {e}" - f"data: {event.data}" - ) - - async def _handle_streaming_request_async( - self, messages_dicts: List[Dict], stop: Optional[List[str]] = None - ) -> AsyncIterator[Dict]: - """ - Performs an async streaming post request to the LLM API. - - Args: - messages_dicts: List of role / content dicts to use as input. - stop: list of stop tokens - - Yields: - An iterator of response dicts. - """ - data = { - "messages": messages_dicts, - "max_tokens": self.max_tokens, - "stop": stop, - "model": self.model, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "stream": True, - "stream_options": self.stream_options, - } - - async with aiohttp.ClientSession() as session: - async with session.post( - self.sambanova_url, - headers={ - "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}", - "Content-Type": "application/json", - }, - json=data, - ) as response: - if response.status != 200: - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status}. {await response.text()}" - ) - - async for line in response.content: - if line: - event = line.decode("utf-8").strip() - - if event.startswith("data:"): - event = event[len("data:") :].strip() - if event == "[DONE]": - break - elif len(event) == 0: - continue - - try: - data = json.loads(event) - if data.get("error"): - raise RuntimeError( - f'Sambanova /complete call failed: {data["error"]}' - ) - yield data - except json.JSONDecodeError: - raise RuntimeError( - f"Sambanova /complete call failed to decode response: {event}" - ) - except Exception as e: - raise RuntimeError( - f"Error processing response: {e} data: {event}" - ) - - @llm_chat_callback() - def chat( - self, - messages: Sequence[ChatMessage], - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> ChatResponse: - """ - Calls the chat implementation of the SambaNovaCloud model. - - Args: - messages: the prompt composed of a list of messages. - stop: a list of strings on which the model should stop generating. - If generation stops due to a stop token, the stop token itself - SHOULD BE INCLUDED as part of the output. This is not enforced - across models right now, but it's a good practice to follow since - it makes it much easier to parse the output of the model - downstream and understand why generation stopped. - - Returns: - ChatResponse with model generation - """ - messages_dicts = _create_message_dicts(messages) - - response = self._handle_request(messages_dicts, stop) - message = ChatMessage( - role=MessageRole.ASSISTANT, - content=response["choices"][0]["message"]["content"], - additional_kwargs={ - "id": response["id"], - "finish_reason": response["choices"][0]["finish_reason"], - "usage": response.get("usage"), - "model_name": response["model"], - "system_fingerprint": response["system_fingerprint"], - "created": response["created"], - }, - ) - return ChatResponse(message=message) - - @llm_chat_callback() - def stream_chat( - self, - messages: Sequence[ChatMessage], - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> ChatResponseGen: - """ - Streams the chat output of the SambaNovaCloud model. - - Args: - messages: the prompt composed of a list of messages. - stop: a list of strings on which the model should stop generating. - If generation stops due to a stop token, the stop token itself - SHOULD BE INCLUDED as part of the output. This is not enforced - across models right now, but it's a good practice to follow since - it makes it much easier to parse the output of the model - downstream and understand why generation stopped. - - Yields: - ChatResponseGen with model partial generation - """ - messages_dicts = _create_message_dicts(messages) - - finish_reason = None - content = "" - role = MessageRole.ASSISTANT - - for partial_response in self._handle_streaming_request(messages_dicts, stop): - if len(partial_response["choices"]) > 0: - content_delta = partial_response["choices"][0]["delta"]["content"] - content += content_delta - additional_kwargs = { - "id": partial_response["id"], - "finish_reason": partial_response["choices"][0].get( - "finish_reason" - ), - } - else: - additional_kwargs = { - "id": partial_response["id"], - "finish_reason": finish_reason, - "usage": partial_response.get("usage"), - "model_name": partial_response["model"], - "system_fingerprint": partial_response["system_fingerprint"], - "created": partial_response["created"], - } - - # yield chunk - yield ChatResponse( - message=ChatMessage( - role=role, content=content, additional_kwargs=additional_kwargs - ), - delta=content_delta, - raw=partial_response, - ) - - @llm_completion_callback() - def complete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponse: - complete_fn = chat_to_completion_decorator(self.chat) - return complete_fn(prompt, **kwargs) - - @llm_completion_callback() - def stream_complete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponseGen: - stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat) - return stream_complete_fn(prompt, **kwargs) - - ### Async ### - @llm_chat_callback() - async def achat( - self, - messages: Sequence[ChatMessage], - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> ChatResponse: - """ - Calls the async chat implementation of the SambaNovaCloud model. - - Args: - messages: the prompt composed of a list of messages. - stop: a list of strings on which the model should stop generating. - If generation stops due to a stop token, the stop token itself - SHOULD BE INCLUDED as part of the output. This is not enforced - across models right now, but it's a good practice to follow since - it makes it much easier to parse the output of the model - downstream and understand why generation stopped. - - Returns: - ChatResponse with async model generation - """ - messages_dicts = _create_message_dicts(messages) - response = await self._handle_request_async(messages_dicts, stop) - message = ChatMessage( - role=MessageRole.ASSISTANT, - content=response["choices"][0]["message"]["content"], - additional_kwargs={ - "id": response["id"], - "finish_reason": response["choices"][0]["finish_reason"], - "usage": response.get("usage"), - "model_name": response["model"], - "system_fingerprint": response["system_fingerprint"], - "created": response["created"], - }, - ) - return ChatResponse(message=message) - - @llm_chat_callback() - async def astream_chat( - self, - messages: Sequence[ChatMessage], - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> ChatResponseAsyncGen: - raise NotImplementedError( - "SambaNovaCloud does not currently support async streaming." - ) - - @llm_completion_callback() - async def acomplete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponse: - acomplete_fn = achat_to_completion_decorator(self.achat) - return await acomplete_fn(prompt, **kwargs) - - @llm_completion_callback() - def astream_complete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponseAsyncGen: - raise NotImplementedError( - "SambaNovaCloud does not currently support async streaming." - ) diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/.gitignore b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/.gitignore similarity index 100% rename from llama-index-integrations/llms/llama-index-llms-sambanovacloud/.gitignore rename to llama-index-integrations/llms/llama-index-llms-sambanovasystems/.gitignore diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/BUILD b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/BUILD similarity index 100% rename from llama-index-integrations/llms/llama-index-llms-sambanovacloud/BUILD rename to llama-index-integrations/llms/llama-index-llms-sambanovasystems/BUILD diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/Makefile b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/Makefile similarity index 100% rename from llama-index-integrations/llms/llama-index-llms-sambanovacloud/Makefile rename to llama-index-integrations/llms/llama-index-llms-sambanovasystems/Makefile diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/README.md b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/README.md similarity index 55% rename from llama-index-integrations/llms/llama-index-llms-sambanovacloud/README.md rename to llama-index-integrations/llms/llama-index-llms-sambanovasystems/README.md index db8cf76156fa8..8168fcaa96f42 100644 --- a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/README.md +++ b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/README.md @@ -1,6 +1,6 @@ # LlamaIndex LLM Integration: SambaNova LLM -SambaNovaLLM is a custom LLM (Language Model) interface that allows you to interact with AI models hosted on SambaNova's offerings - SambaNova Cloud and SambaStudio +SambaNova Systems LLMs are custom LLMs (Language Models) interfaces that allow you to interact with AI models hosted on SambaNova's offerings - SambaNova Cloud and SambaStudio ## Key Features: @@ -13,13 +13,13 @@ SambaNovaLLM is a custom LLM (Language Model) interface that allows you to inter ## Installation ```bash -pip install llama-index-llms-sambanovacloud +pip install llama-index-llms-sambanovasystems ``` ## Usage ```python -from llama_index.llms.sambanovacloud import SambaNovaCloud +from llama_index.llms.sambanovasystems import SambaNovaCloud SambaNovaCloud( sambanova_url="SambaNova cloud endpoint URL", @@ -31,9 +31,11 @@ SambaNovaCloud( ## Usage ```python -SambaNovaCloud( - sambanova_url="SambaNova cloud endpoint URL", - sambanova_api_key="set with your SambaNova cloud API key", +from llama_index.llms.sambanovasystems import SambaStudio + +SambaStudio( + sambastudio_url="SambaStudio endpoint URL", + sambastudio_api_key="set with your SambaStudio endppoint API key", model="model name", ) ``` diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/llama_index/llms/sambanovacloud/BUILD b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/llama_index/llms/sambanovasystems/BUILD similarity index 100% rename from llama-index-integrations/llms/llama-index-llms-sambanovacloud/llama_index/llms/sambanovacloud/BUILD rename to llama-index-integrations/llms/llama-index-llms-sambanovasystems/llama_index/llms/sambanovasystems/BUILD diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovasystems/llama_index/llms/sambanovasystems/__init__.py b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/llama_index/llms/sambanovasystems/__init__.py new file mode 100644 index 0000000000000..63cd1bbdb5dd6 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/llama_index/llms/sambanovasystems/__init__.py @@ -0,0 +1,3 @@ +from llama_index.llms.sambanovasystems.base import SambaNovaCloud, SambaStudio + +__all__ = ["SambaNovaCloud", "SambaStudio"] diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovasystems/llama_index/llms/sambanovasystems/base.py b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/llama_index/llms/sambanovasystems/base.py new file mode 100644 index 0000000000000..60b551d81e032 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/llama_index/llms/sambanovasystems/base.py @@ -0,0 +1,1570 @@ +import aiohttp + +from typing import Any, Dict, List, Optional, Iterator, Sequence, AsyncIterator, Tuple + +import requests +from llama_index.core.llms.llm import LLM +from llama_index.core.llms.callbacks import ( + llm_chat_callback, + llm_completion_callback, +) +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) +from llama_index.core.base.llms.generic_utils import ( + get_from_param_or_env, + chat_to_completion_decorator, + stream_chat_to_completion_decorator, + achat_to_completion_decorator, +) +from llama_index.core.bridge.pydantic import Field, SecretStr +import json +from requests import Response + +from dotenv import load_dotenv + +load_dotenv() + + +def _convert_message_to_dict(message: ChatMessage) -> Dict[str, Any]: + """Converts a ChatMessage to a dictionary with Role / content. + + Args: + message: ChatMessage + + Returns: + messages_dict: role / content dict + """ + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + else: + raise TypeError(f"Got unknown type {message}") + return message_dict + + +def _create_message_dicts(messages: Sequence[ChatMessage]) -> List[Dict[str, Any]]: + """Converts a list of ChatMessages to a list of dictionaries with Role / content. + + Args: + messages: list of ChatMessages + + Returns: + messages_dicts: list of role / content dicts + """ + return [_convert_message_to_dict(m) for m in messages] + + +class SambaNovaCloud(LLM): + """ + SambaNova Cloud model. + + Setup: + To use, you should have the environment variables: + ``SAMBANOVA_URL`` set with your SambaNova Cloud URL. + ``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API Key. + http://cloud.sambanova.ai/ + + Example: + .. code-block:: python + SambaNovaCloud( + sambanova_url = SambaNova cloud endpoint URL, + sambanova_api_key = set with your SambaNova cloud API key, + model = model name, + max_tokens = max number of tokens to generate, + temperature = model temperature, + top_p = model top p, + top_k = model top k, + stream_options = include usage to get generation metrics + ) + + Key init args — completion params: + model: str + The name of the model to use, e.g., Meta-Llama-3-70B-Instruct. + streaming: bool + Whether to use streaming handler when using non streaming methods + max_tokens: int + max tokens to generate + temperature: float + model temperature + top_p: float + model top p + top_k: int + model top k + stream_options: dict + stream options, include usage to get generation metrics + + Key init args — client params: + sambanova_url: str + SambaNova Cloud Url + sambanova_api_key: str + SambaNova Cloud api key + + Instantiate: + .. code-block:: python + + from llama_index.llms.sambanovacloud import SambaNovaCloud + + llm = SambaNovaCloud( + sambanova_url = SambaNova cloud endpoint URL, + sambanova_api_key = set with your SambaNova cloud API key, + model = model name, + max_tokens = max number of tokens to generate, + temperature = model temperature, + top_p = model top p, + top_k = model top k, + stream_options = include usage to get generation metrics + ) + Complete: + .. code-block:: python + prompt = "Tell me about Naruto Uzumaki in one sentence" + response = llm.complete(prompt) + + Chat: + .. code-block:: python + messages = [ + ChatMessage(role=MessageRole.SYSTEM, content=("You're a helpful assistant")), + ChatMessage(role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence") + ] + response = llm.chat(messages) + + Stream: + .. code-block:: python + prompt = "Tell me about Naruto Uzumaki in one sentence" + messages = [ + ChatMessage(role=MessageRole.SYSTEM, content=("You're a helpful assistant")), + ChatMessage(role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence") + ] + for chunk in llm.stream_complete(prompt): + print(chunk.text) + for chunk in llm.stream_chat(messages): + print(chunk.message.content) + + Async: + .. code-block:: python + prompt = "Tell me about Naruto Uzumaki in one sentence" + asyncio.run(llm.acomplete(prompt)) + + messages = [ + ChatMessage(role=MessageRole.SYSTEM, content=("You're a helpful assistant")), + ChatMessage(role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence") + ] + asyncio.run(llm.achat(chat_text_msgs)) + + Response metadata and usage + .. code-block:: python + + messages = [ + ChatMessage(role=MessageRole.SYSTEM, content=("You're a helpful assistant")), + ChatMessage(role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence") + ] + metadata_and_usage = llm.chat(messages).message.additional_kwargs + print(metadata_and_usage) + """ + + sambanova_url: str = Field(default_factory=str, description="SambaNova Cloud Url") + + sambanova_api_key: SecretStr = Field( + default_factory=str, description="SambaNova Cloud api key" + ) + + model: str = Field( + default="Meta-Llama-3.1-8B-Instruct", + description="The name of the model", + ) + + streaming: bool = Field( + default=False, + description="Whether to use streaming handler when using non streaming methods", + ) + + max_tokens: int = Field(default=1024, description="max tokens to generate") + + temperature: float = Field(default=0.7, description="model temperature") + + top_p: Optional[float] = Field(default=None, description="model top p") + + top_k: Optional[int] = Field(default=None, description="model top k") + + stream_options: dict = Field( + default={"include_usage": True}, + description="stream options, include usage to get generation metrics", + ) + + @classmethod + def class_name(cls) -> str: + return "SambaNovaCloud" + + @property + def metadata(self) -> LLMMetadata: + return LLMMetadata( + context_window=None, + num_output=self.max_tokens, + is_chat_model=True, + model_name=self.model, + ) + + def __init__(self, **kwargs: Any) -> None: + """Init and validate environment variables.""" + kwargs["sambanova_url"] = get_from_param_or_env( + "url", + kwargs.get("sambanova_url"), + "SAMBANOVA_URL", + default="https://api.sambanova.ai/v1/chat/completions", + ) + kwargs["sambanova_api_key"] = get_from_param_or_env( + "api_key", kwargs.get("sambanova_api_key"), "SAMBANOVA_API_KEY" + ) + super().__init__(**kwargs) + + def _handle_request( + self, messages_dicts: List[Dict], stop: Optional[List[str]] = None + ) -> Dict[str, Any]: + """ + Performs a post request to the LLM API. + + Args: + messages_dicts: List of role / content dicts to use as input. + stop: list of stop tokens + + Returns: + A response dict. + """ + data = { + "messages": messages_dicts, + "max_tokens": self.max_tokens, + "stop": stop, + "model": self.model, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + } + http_session = requests.Session() + response = http_session.post( + self.sambanova_url, + headers={ + "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}", + "Content-Type": "application/json", + }, + json=data, + ) + if response.status_code != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}.", + f"{response.text}.", + ) + response_dict = response.json() + if response_dict.get("error"): + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}.", + f"{response_dict}.", + ) + return response_dict + + async def _handle_request_async( + self, messages_dicts: List[Dict], stop: Optional[List[str]] = None + ) -> Dict[str, Any]: + """ + Performs a async post request to the LLM API. + + Args: + messages_dicts: List of role / content dicts to use as input. + stop: list of stop tokens + + Returns: + A response dict. + """ + data = { + "messages": messages_dicts, + "max_tokens": self.max_tokens, + "stop": stop, + "model": self.model, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + self.sambanova_url, + headers={ + "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}", + "Content-Type": "application/json", + }, + json=data, + ) as response: + if response.status != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code {response.status}.", + f"{await response.text()}.", + ) + response_dict = await response.json() + if response_dict.get("error"): + raise RuntimeError( + f"Sambanova /complete call failed with status code {response.status}.", + f"{response_dict}.", + ) + return response_dict + + def _handle_streaming_request( + self, messages_dicts: List[Dict], stop: Optional[List[str]] = None + ) -> Iterator[Dict]: + """ + Performs an streaming post request to the LLM API. + + Args: + messages_dicts: List of role / content dicts to use as input. + stop: list of stop tokens + + Yields: + An iterator of response dicts. + """ + try: + import sseclient + except ImportError: + raise ImportError( + "could not import sseclient library" + "Please install it with `pip install sseclient-py`." + ) + data = { + "messages": messages_dicts, + "max_tokens": self.max_tokens, + "stop": stop, + "model": self.model, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "stream": True, + "stream_options": self.stream_options, + } + http_session = requests.Session() + response = http_session.post( + self.sambanova_url, + headers={ + "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}", + "Content-Type": "application/json", + }, + json=data, + stream=True, + ) + + client = sseclient.SSEClient(response) + + if response.status_code != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}." + f"{response.text}." + ) + + for event in client.events(): + if event.event == "error_event": + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}." + f"{event.data}." + ) + + try: + # check if the response is a final event + # in that case event data response is '[DONE]' + if event.data != "[DONE]": + if isinstance(event.data, str): + data = json.loads(event.data) + else: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}." + f"{event.data}." + ) + if data.get("error"): + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}." + f"{event.data}." + ) + yield data + except Exception as e: + raise RuntimeError( + f"Error getting content chunk raw streamed response: {e}" + f"data: {event.data}" + ) + + async def _handle_streaming_request_async( + self, messages_dicts: List[Dict], stop: Optional[List[str]] = None + ) -> AsyncIterator[Dict]: + """ + Performs an async streaming post request to the LLM API. + + Args: + messages_dicts: List of role / content dicts to use as input. + stop: list of stop tokens + + Yields: + An iterator of response dicts. + """ + data = { + "messages": messages_dicts, + "max_tokens": self.max_tokens, + "stop": stop, + "model": self.model, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "stream": True, + "stream_options": self.stream_options, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + self.sambanova_url, + headers={ + "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}", + "Content-Type": "application/json", + }, + json=data, + ) as response: + if response.status != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status}. {await response.text()}" + ) + + async for line in response.content: + if line: + event = line.decode("utf-8").strip() + + if event.startswith("data:"): + event = event[len("data:") :].strip() + if event == "[DONE]": + break + elif len(event) == 0: + continue + + try: + data = json.loads(event) + if data.get("error"): + raise RuntimeError( + f'Sambanova /complete call failed: {data["error"]}' + ) + yield data + except json.JSONDecodeError: + raise RuntimeError( + f"Sambanova /complete call failed to decode response: {event}" + ) + except Exception as e: + raise RuntimeError( + f"Error processing response: {e} data: {event}" + ) + + @llm_chat_callback() + def chat( + self, + messages: Sequence[ChatMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> ChatResponse: + """ + Calls the chat implementation of the SambaNovaCloud model. + + Args: + messages: the prompt composed of a list of messages. + stop: a list of strings on which the model should stop generating. + If generation stops due to a stop token, the stop token itself + SHOULD BE INCLUDED as part of the output. This is not enforced + across models right now, but it's a good practice to follow since + it makes it much easier to parse the output of the model + downstream and understand why generation stopped. + + Returns: + ChatResponse with model generation + """ + messages_dicts = _create_message_dicts(messages) + + response = self._handle_request(messages_dicts, stop) + message = ChatMessage( + role=MessageRole.ASSISTANT, + content=response["choices"][0]["message"]["content"], + additional_kwargs={ + "id": response["id"], + "finish_reason": response["choices"][0]["finish_reason"], + "usage": response.get("usage"), + "model_name": response["model"], + "system_fingerprint": response["system_fingerprint"], + "created": response["created"], + }, + ) + return ChatResponse(message=message) + + @llm_chat_callback() + def stream_chat( + self, + messages: Sequence[ChatMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> ChatResponseGen: + """ + Streams the chat output of the SambaNovaCloud model. + + Args: + messages: the prompt composed of a list of messages. + stop: a list of strings on which the model should stop generating. + If generation stops due to a stop token, the stop token itself + SHOULD BE INCLUDED as part of the output. This is not enforced + across models right now, but it's a good practice to follow since + it makes it much easier to parse the output of the model + downstream and understand why generation stopped. + + Yields: + ChatResponseGen with model partial generation + """ + messages_dicts = _create_message_dicts(messages) + + finish_reason = None + content = "" + role = MessageRole.ASSISTANT + + for partial_response in self._handle_streaming_request(messages_dicts, stop): + if len(partial_response["choices"]) > 0: + content_delta = partial_response["choices"][0]["delta"]["content"] + content += content_delta + additional_kwargs = { + "id": partial_response["id"], + "finish_reason": partial_response["choices"][0].get( + "finish_reason" + ), + } + else: + additional_kwargs = { + "id": partial_response["id"], + "finish_reason": finish_reason, + "usage": partial_response.get("usage"), + "model_name": partial_response["model"], + "system_fingerprint": partial_response["system_fingerprint"], + "created": partial_response["created"], + } + + # yield chunk + yield ChatResponse( + message=ChatMessage( + role=role, content=content, additional_kwargs=additional_kwargs + ), + delta=content_delta, + raw=partial_response, + ) + + @llm_completion_callback() + def complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + complete_fn = chat_to_completion_decorator(self.chat) + return complete_fn(prompt, **kwargs) + + @llm_completion_callback() + def stream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseGen: + stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat) + return stream_complete_fn(prompt, **kwargs) + + ### Async ### + @llm_chat_callback() + async def achat( + self, + messages: Sequence[ChatMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> ChatResponse: + """ + Calls the async chat implementation of the SambaNovaCloud model. + + Args: + messages: the prompt composed of a list of messages. + stop: a list of strings on which the model should stop generating. + If generation stops due to a stop token, the stop token itself + SHOULD BE INCLUDED as part of the output. This is not enforced + across models right now, but it's a good practice to follow since + it makes it much easier to parse the output of the model + downstream and understand why generation stopped. + + Returns: + ChatResponse with async model generation + """ + messages_dicts = _create_message_dicts(messages) + response = await self._handle_request_async(messages_dicts, stop) + message = ChatMessage( + role=MessageRole.ASSISTANT, + content=response["choices"][0]["message"]["content"], + additional_kwargs={ + "id": response["id"], + "finish_reason": response["choices"][0]["finish_reason"], + "usage": response.get("usage"), + "model_name": response["model"], + "system_fingerprint": response["system_fingerprint"], + "created": response["created"], + }, + ) + return ChatResponse(message=message) + + @llm_chat_callback() + async def astream_chat( + self, + messages: Sequence[ChatMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> ChatResponseAsyncGen: + raise NotImplementedError( + "SambaNovaCloud does not currently support async streaming." + ) + + @llm_completion_callback() + async def acomplete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + acomplete_fn = achat_to_completion_decorator(self.achat) + return await acomplete_fn(prompt, **kwargs) + + @llm_completion_callback() + def astream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseAsyncGen: + raise NotImplementedError( + "SambaNovaCloud does not currently support async streaming." + ) + + +class SambaStudio(LLM): + """ + SambaStudio model. + + Setup: + To use, you should have the environment variables: + ``SAMBASTUDIO_URL`` set with your SambaStudio deployed endpoint URL. + ``SAMBASTUDIO_API_KEY`` set with your SambaStudio deployed endpoint Key. + https://docs.sambanova.ai/sambastudio/latest/index.html + Example: + .. code-block:: python + SambaStudio( + sambastudio_url = set with your SambaStudio deployed endpoint URL, + sambastudio_api_key = set with your SambaStudio deployed endpoint Key. + model = model or expert name (set for CoE endpoints), + max_tokens = max number of tokens to generate, + temperature = model temperature, + top_p = model top p, + top_k = model top k, + do_sample = whether to do sample + process_prompt = whether to process prompt + (set for CoE generic v1 and v2 endpoints) + stream_options = include usage to get generation metrics + special_tokens = start, start_role, end_role, end special tokens + (set for CoE generic v1 and v2 endpoints when process prompt + set to false or for StandAlone v1 and v2 endpoints) + model_kwargs: Optional = Extra Key word arguments to pass to the model. + ) + + Key init args — completion params: + model: str + The name of the model to use, e.g., Meta-Llama-3-70B-Instruct-4096 + (set for CoE endpoints). + streaming: bool + Whether to use streaming + max_tokens: inthandler when using non streaming methods + max tokens to generate + temperature: float + model temperature + top_p: float + model top p + top_k: int + model top k + do_sample: bool + whether to do sample + process_prompt: + whether to process prompt (set for CoE generic v1 and v2 endpoints) + stream_options: dict + stream options, include usage to get generation metrics + special_tokens: dict + start, start_role, end_role and end special tokens + (set for CoE generic v1 and v2 endpoints when process prompt set to false + or for StandAlone v1 and v2 endpoints) default to llama3 special tokens + model_kwargs: dict + Extra Key word arguments to pass to the model. + + Key init args — client params: + sambastudio_url: str + SambaStudio endpoint Url + sambastudio_api_key: str + SambaStudio endpoint api key + + Instantiate: + .. code-block:: python + + from llama_index.llms.sambanova import SambaStudio + + llm = SambaStudio=( + sambastudio_url = set with your SambaStudio deployed endpoint URL, + sambastudio_api_key = set with your SambaStudio deployed endpoint Key. + model = model or expert name (set for CoE endpoints), + max_tokens = max number of tokens to generate, + temperature = model temperature, + top_p = model top p, + top_k = model top k, + do_sample = whether to do sample + process_prompt = whether to process prompt + (set for CoE generic v1 and v2 endpoints) + stream_options = include usage to get generation metrics + special_tokens = start, start_role, end_role, and special tokens + (set for CoE generic v1 and v2 endpoints when process prompt + set to false or for StandAlone v1 and v2 endpoints) + model_kwargs: Optional = Extra Key word arguments to pass to the model. + ) + Complete: + .. code-block:: python + prompt = "Tell me about Naruto Uzumaki in one sentence" + response = llm.complete(prompt) + + Chat: + .. code-block:: python + messages = [ + ChatMessage(role=MessageRole.SYSTEM, content=("You're a helpful assistant")), + ChatMessage(role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence") + ] + response = llm.chat(messages) + + Stream: + .. code-block:: python + prompt = "Tell me about Naruto Uzumaki in one sentence" + messages = [ + ChatMessage(role=MessageRole.SYSTEM, content=("You're a helpful assistant")), + ChatMessage(role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence") + ] + for chunk in llm.stream_complete(prompt): + print(chunk.text) + for chunk in llm.stream_chat(messages): + print(chunk.message.content) + + Async: + .. code-block:: python + prompt = "Tell me about Naruto Uzumaki in one sentence" + asyncio.run(llm.acomplete(prompt)) + + messages = [ + ChatMessage(role=MessageRole.SYSTEM, content=("You're a helpful assistant")), + ChatMessage(role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence") + ] + asyncio.run(llm.achat(chat_text_msgs)) + + Response metadata and usage + .. code-block:: python + + messages = [ + ChatMessage(role=MessageRole.SYSTEM, content=("You're a helpful assistant")), + ChatMessage(role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence") + ] + metadata_and_usage = llm.chat(messages).message.additional_kwargs + print(metadata_and_usage) + """ + + sambastudio_url: str = Field(default_factory=str, description="SambaStudio Url") + + sambastudio_api_key: SecretStr = Field( + default_factory=str, description="SambaStudio api key" + ) + + base_url: str = Field( + default_factory=str, exclude=True, description="SambaStudio non streaming Url" + ) + + streaming_url: str = Field( + default_factory=str, exclude=True, description="SambaStudio streaming Url" + ) + + model: Optional[str] = Field( + default_factory=Optional[str], + description="The name of the model or expert to use (for CoE endpoints)", + ) + + streaming: bool = Field( + default=False, + description="Whether to use streaming handler when using non streaming methods", + ) + + max_tokens: int = Field(default=1024, description="max tokens to generate") + + temperature: Optional[float] = Field(default=0.7, description="model temperature") + + top_p: Optional[float] = Field(default=None, description="model top p") + + top_k: Optional[int] = Field(default=None, description="model top k") + + do_sample: Optional[bool] = Field( + default=None, description="whether to do sampling" + ) + + process_prompt: Optional[bool] = Field( + default=True, + description="whether process prompt (for CoE generic v1 and v2 endpoints)", + ) + + stream_options: dict = Field( + default={"include_usage": True}, + description="stream options, include usage to get generation metrics", + ) + + special_tokens: dict = Field( + default={ + "start": "<|begin_of_text|>", + "start_role": "<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>", + "end_role": "<|eot_id|>", + "end": "<|start_header_id|>assistant<|end_header_id|>\n", + }, + description="start, start_role, end_role and end special tokens (set for CoE generic v1 and v2 endpoints when process prompt set to false or for StandAlone v1 and v2 endpoints) default to llama3 special tokens", + ) + + model_kwargs: Optional[Dict[str, Any]] = Field( + default=None, description="Key word arguments to pass to the model." + ) + + @classmethod + def class_name(cls) -> str: + return "SambaStudio" + + @property + def metadata(self) -> LLMMetadata: + return LLMMetadata( + context_window=None, + num_output=self.max_tokens, + is_chat_model=True, + model_name=self.model, + ) + + def __init__(self, **kwargs: Any) -> None: + """Init and validate environment variables.""" + kwargs["sambastudio_url"] = get_from_param_or_env( + "url", kwargs.get("sambastudio_url"), "SAMBASTUDIO_URL" + ) + kwargs["sambastudio_api_key"] = get_from_param_or_env( + "api_key", kwargs.get("sambastudio_api_key"), "SAMBASTUDIO_API_KEY" + ) + kwargs["base_url"], kwargs["streaming_url"] = self._get_sambastudio_urls( + kwargs["sambastudio_url"] + ) + super().__init__(**kwargs) + + def _messages_to_string(self, messages: Sequence[ChatMessage]) -> str: + """Convert a sequence of ChatMessages to: + - dumped json string with Role / content dict structure when process_prompt is true, + - string with special tokens if process_prompt is false for generic V1 and V2 endpoints. + + Args: + messages: sequence of ChatMessages + + Returns: + str: string to send as model input depending on process_prompt param + """ + if self.process_prompt: + messages_dict: Dict[str, Any] = { + "conversation_id": "sambaverse-conversation-id", + "messages": [], + } + for message in messages: + messages_dict["messages"].append( + { + "role": message.role, + "content": message.content, + } + ) + messages_string = json.dumps(messages_dict) + else: + messages_string = self.special_tokens["start"] + for message in messages: + messages_string += self.special_tokens["start_role"].format( + role=self._get_role(message) + ) + messages_string += f" {message.content} " + messages_string += self.special_tokens["end_role"] + messages_string += self.special_tokens["end"] + + return messages_string + + def _get_sambastudio_urls(self, url: str) -> Tuple[str, str]: + """Get streaming and non streaming URLs from the given URL. + + Args: + url: string with sambastudio base or streaming endpoint url + + Returns: + base_url: string with url to do non streaming calls + streaming_url: string with url to do streaming calls + """ + if "openai" in url: + base_url = url + stream_url = url + else: + if "stream" in url: + base_url = url.replace("stream/", "") + stream_url = url + else: + base_url = url + if "generic" in url: + stream_url = "generic/stream".join(url.split("generic")) + else: + raise ValueError("Unsupported URL") + return base_url, stream_url + + def _handle_request( + self, + messages: Sequence[ChatMessage], + stop: Optional[List[str]] = None, + streaming: Optional[bool] = False, + ) -> Response: + """Performs a post request to the LLM API. + + Args: + messages_dicts: List of role / content dicts to use as input. + stop: list of stop tokens + streaming: whether to do a streaming call + + Returns: + A request Response object + """ + # create request payload for openai compatible API + if "openai" in self.sambastudio_url: + messages_dicts = _create_message_dicts(messages) + data = { + "messages": messages_dicts, + "max_tokens": self.max_tokens, + "stop": stop, + "model": self.model, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "stream": streaming, + "stream_options": self.stream_options, + } + data = {key: value for key, value in data.items() if value is not None} + headers = { + "Authorization": f"Bearer " + f"{self.sambastudio_api_key.get_secret_value()}", + "Content-Type": "application/json", + } + + # create request payload for generic v1 API + elif "api/v2/predict/generic" in self.sambastudio_url: + items = [{"id": "item0", "value": self._messages_to_string(messages)}] + params: Dict[str, Any] = { + "select_expert": self.model, + "process_prompt": self.process_prompt, + "max_tokens_to_generate": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "do_sample": self.do_sample, + } + if self.model_kwargs is not None: + params = {**params, **self.model_kwargs} + params = {key: value for key, value in params.items() if value is not None} + data = {"items": items, "params": params} + headers = {"key": self.sambastudio_api_key.get_secret_value()} + + # create request payload for generic v1 API + elif "api/predict/generic" in self.sambastudio_url: + params = { + "select_expert": self.model, + "process_prompt": self.process_prompt, + "max_tokens_to_generate": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "do_sample": self.do_sample, + } + if self.model_kwargs is not None: + params = {**params, **self.model_kwargs} + params = { + key: {"type": type(value).__name__, "value": str(value)} + for key, value in params.items() + if value is not None + } + if streaming: + data = { + "instance": self._messages_to_string(messages), + "params": params, + } + else: + data = { + "instances": [self._messages_to_string(messages)], + "params": params, + } + headers = {"key": self.sambastudio_api_key.get_secret_value()} + + else: + raise ValueError( + f"Unsupported URL{self.sambastudio_url}" + "only openai, generic v1 and generic v2 APIs are supported" + ) + + http_session = requests.Session() + if streaming: + response = http_session.post( + self.streaming_url, headers=headers, json=data, stream=True + ) + else: + response = http_session.post( + self.base_url, headers=headers, json=data, stream=False + ) + if response.status_code != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}." + f"{response.text}." + ) + return response + + async def _handle_request_async( + self, + messages: Sequence[ChatMessage], + stop: Optional[List[str]] = None, + streaming: Optional[bool] = False, + ) -> Response: + """Performs an async post request to the LLM API. + + Args: + messages_dicts: List of role / content dicts to use as input. + stop: list of stop tokens + streaming: whether to do a streaming call + + Returns: + A request Response object + """ + # create request payload for openai compatible API + if "openai" in self.sambastudio_url: + messages_dicts = _create_message_dicts(messages) + data = { + "messages": messages_dicts, + "max_tokens": self.max_tokens, + "stop": stop, + "model": self.model, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "stream": streaming, + "stream_options": self.stream_options, + } + data = {key: value for key, value in data.items() if value is not None} + headers = { + "Authorization": f"Bearer " + f"{self.sambastudio_api_key.get_secret_value()}", + "Content-Type": "application/json", + } + + # create request payload for generic v1 API + elif "api/v2/predict/generic" in self.sambastudio_url: + items = [{"id": "item0", "value": self._messages_to_string(messages)}] + params: Dict[str, Any] = { + "select_expert": self.model, + "process_prompt": self.process_prompt, + "max_tokens_to_generate": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "do_sample": self.do_sample, + } + if self.model_kwargs is not None: + params = {**params, **self.model_kwargs} + params = {key: value for key, value in params.items() if value is not None} + data = {"items": items, "params": params} + headers = {"key": self.sambastudio_api_key.get_secret_value()} + + # create request payload for generic v1 API + elif "api/predict/generic" in self.sambastudio_url: + params = { + "select_expert": self.model, + "process_prompt": self.process_prompt, + "max_tokens_to_generate": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "do_sample": self.do_sample, + } + if self.model_kwargs is not None: + params = {**params, **self.model_kwargs} + params = { + key: {"type": type(value).__name__, "value": str(value)} + for key, value in params.items() + if value is not None + } + if streaming: + data = { + "instance": self._messages_to_string(messages), + "params": params, + } + else: + data = { + "instances": [self._messages_to_string(messages)], + "params": params, + } + headers = {"key": self.sambastudio_api_key.get_secret_value()} + + else: + raise ValueError( + f"Unsupported URL{self.sambastudio_url}" + "only openai, generic v1 and generic v2 APIs are supported" + ) + + async with aiohttp.ClientSession() as session: + if streaming: + url = self.streaming_url + else: + url = self.base_url + + async with session.post( + url, + headers=headers, + json=data, + ) as response: + if response.status != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status}." + f"{response.text}." + ) + response_dict = await response.json() + if response_dict.get("error"): + raise RuntimeError( + f"Sambanova /complete call failed with status code {response.status}.", + f"{response_dict}.", + ) + return response_dict + + def _process_response(self, response: Response) -> ChatMessage: + """Process a non streaming response from the api. + + Args: + response: A request Response object + + Returns: + generation: a ChatMessage with model generation + """ + # Extract json payload form response + try: + response_dict = response.json() + except Exception as e: + raise RuntimeError( + f"Sambanova /complete call failed couldn't get JSON response {e}" + f"response: {response.text}" + ) + + # process response payload for openai compatible API + if "openai" in self.sambastudio_url: + content = response_dict["choices"][0]["message"]["content"] + response_metadata = { + "finish_reason": response_dict["choices"][0]["finish_reason"], + "usage": response_dict.get("usage"), + "model_name": response_dict["model"], + "system_fingerprint": response_dict["system_fingerprint"], + "created": response_dict["created"], + } + + # process response payload for generic v2 API + elif "api/v2/predict/generic" in self.sambastudio_url: + content = response_dict["items"][0]["value"]["completion"] + response_metadata = response_dict["items"][0] + + # process response payload for generic v1 API + elif "api/predict/generic" in self.sambastudio_url: + content = response_dict["predictions"][0]["completion"] + response_metadata = response_dict + + else: + raise ValueError( + f"Unsupported URL{self.sambastudio_url}" + "only openai, generic v1 and generic v2 APIs are supported" + ) + + return ChatMessage( + content=content, + additional_kwargs=response_metadata, + role=MessageRole.ASSISTANT, + ) + + def _process_stream_response(self, response: Response) -> Iterator[ChatMessage]: + """Process a streaming response from the api. + + Args: + response: An iterable request Response object + + Yields: + generation: an Iterator[ChatMessage] with model partial generation + """ + try: + import sseclient + except ImportError: + raise ImportError( + "could not import sseclient library" + "Please install it with `pip install sseclient-py`." + ) + + # process response payload for openai compatible API + if "openai" in self.sambastudio_url: + finish_reason = "" + content = "" + client = sseclient.SSEClient(response) + for event in client.events(): + if event.event == "error_event": + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}." + f"{event.data}." + ) + try: + # check if the response is not a final event ("[DONE]") + if event.data != "[DONE]": + if isinstance(event.data, str): + data = json.loads(event.data) + else: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}." + f"{event.data}." + ) + if data.get("error"): + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}." + f"{event.data}." + ) + if len(data["choices"]) > 0: + finish_reason = data["choices"][0].get("finish_reason") + content += data["choices"][0]["delta"]["content"] + id = data["id"] + metadata = {} + else: + content += "" + id = data["id"] + metadata = { + "finish_reason": finish_reason, + "usage": data.get("usage"), + "model_name": data["model"], + "system_fingerprint": data["system_fingerprint"], + "created": data["created"], + } + if data.get("usage") is not None: + content += "" + id = data["id"] + metadata = { + "finish_reason": finish_reason, + "usage": data.get("usage"), + "model_name": data["model"], + "system_fingerprint": data["system_fingerprint"], + "created": data["created"], + } + yield ChatMessage( + role=MessageRole.ASSISTANT, + content=content, + additional_kwargs=metadata, + ) + + except Exception as e: + raise RuntimeError( + f"Error getting content chunk raw streamed response: {e}" + f"data: {event.data}" + ) + + # process response payload for generic v2 API + elif "api/v2/predict/generic" in self.sambastudio_url: + content = "" + for line in response.iter_lines(): + try: + data = json.loads(line) + content += data["result"]["items"][0]["value"]["stream_token"] + id = data["result"]["items"][0]["id"] + if data["result"]["items"][0]["value"]["is_last_response"]: + metadata = { + "finish_reason": data["result"]["items"][0]["value"].get( + "stop_reason" + ), + "prompt": data["result"]["items"][0]["value"].get("prompt"), + "usage": { + "prompt_tokens_count": data["result"]["items"][0][ + "value" + ].get("prompt_tokens_count"), + "completion_tokens_count": data["result"]["items"][0][ + "value" + ].get("completion_tokens_count"), + "total_tokens_count": data["result"]["items"][0][ + "value" + ].get("total_tokens_count"), + "start_time": data["result"]["items"][0]["value"].get( + "start_time" + ), + "end_time": data["result"]["items"][0]["value"].get( + "end_time" + ), + "model_execution_time": data["result"]["items"][0][ + "value" + ].get("model_execution_time"), + "time_to_first_token": data["result"]["items"][0][ + "value" + ].get("time_to_first_token"), + "throughput_after_first_token": data["result"]["items"][ + 0 + ]["value"].get("throughput_after_first_token"), + "batch_size_used": data["result"]["items"][0][ + "value" + ].get("batch_size_used"), + }, + } + else: + metadata = {} + yield ChatMessage( + role=MessageRole.ASSISTANT, + content=content, + additional_kwargs=metadata, + ) + + except Exception as e: + raise RuntimeError( + f"Error getting content chunk raw streamed response: {e}" + f"line: {line}" + ) + + # process response payload for generic v1 API + elif "api/predict/generic" in self.sambastudio_url: + content = "" + for line in response.iter_lines(): + try: + data = json.loads(line) + content += data["result"]["responses"][0]["stream_token"] + id = None + if data["result"]["responses"][0]["is_last_response"]: + metadata = { + "finish_reason": data["result"]["responses"][0].get( + "stop_reason" + ), + "prompt": data["result"]["responses"][0].get("prompt"), + "usage": { + "prompt_tokens_count": data["result"]["responses"][ + 0 + ].get("prompt_tokens_count"), + "completion_tokens_count": data["result"]["responses"][ + 0 + ].get("completion_tokens_count"), + "total_tokens_count": data["result"]["responses"][ + 0 + ].get("total_tokens_count"), + "start_time": data["result"]["responses"][0].get( + "start_time" + ), + "end_time": data["result"]["responses"][0].get( + "end_time" + ), + "model_execution_time": data["result"]["responses"][ + 0 + ].get("model_execution_time"), + "time_to_first_token": data["result"]["responses"][ + 0 + ].get("time_to_first_token"), + "throughput_after_first_token": data["result"][ + "responses" + ][0].get("throughput_after_first_token"), + "batch_size_used": data["result"]["responses"][0].get( + "batch_size_used" + ), + }, + } + else: + metadata = {} + yield ChatMessage( + role=MessageRole.ASSISTANT, + content=content, + additional_kwargs=metadata, + ) + + except Exception as e: + raise RuntimeError( + f"Error getting content chunk raw streamed response: {e}" + f"line: {line}" + ) + + else: + raise ValueError( + f"Unsupported URL{self.sambastudio_url}" + "only openai, generic v1 and generic v2 APIs are supported" + ) + + async def _process_response_async( + self, response_dict: Dict[str, Any] + ) -> ChatMessage: + """Process a non streaming response from the api. + + Args: + response: A request Response object + + Returns: + generation: a ChatMessage with model generation + """ + # process response payload for openai compatible API + if "openai" in self.sambastudio_url: + content = response_dict["choices"][0]["message"]["content"] + response_metadata = { + "finish_reason": response_dict["choices"][0]["finish_reason"], + "usage": response_dict.get("usage"), + "model_name": response_dict["model"], + "system_fingerprint": response_dict["system_fingerprint"], + "created": response_dict["created"], + } + + # process response payload for generic v2 API + elif "api/v2/predict/generic" in self.sambastudio_url: + content = response_dict["items"][0]["value"]["completion"] + response_metadata = response_dict["items"][0] + + # process response payload for generic v1 API + elif "api/predict/generic" in self.sambastudio_url: + content = response_dict["predictions"][0]["completion"] + response_metadata = response_dict + + else: + raise ValueError( + f"Unsupported URL{self.sambastudio_url}" + "only openai, generic v1 and generic v2 APIs are supported" + ) + + return ChatMessage( + content=content, + additional_kwargs=response_metadata, + role=MessageRole.ASSISTANT, + ) + + @llm_chat_callback() + def chat( + self, + messages: Sequence[ChatMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> ChatResponse: + """Calls the chat implementation of the SambaStudio model. + + Args: + messages: the prompt composed of a list of messages. + stop: a list of strings on which the model should stop generating. + If generation stops due to a stop token, the stop token itself + SHOULD BE INCLUDED as part of the output. This is not enforced + across models right now, but it's a good practice to follow since + it makes it much easier to parse the output of the model + downstream and understand why generation stopped. + + Returns: + ChatResponse with model generation + """ + # if self.streaming: + # stream_iter = self._stream( + # messages, stop=stop, **kwargs + # ) + # if stream_iter: + # return generate_from_stream(stream_iter) + response = self._handle_request(messages, stop, streaming=False) + message = self._process_response(response) + + return ChatResponse(message=message) + + @llm_chat_callback() + def stream_chat( + self, + messages: Sequence[ChatMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> ChatResponseGen: + """Stream the output of the SambaStudio model. + + Args: + messages: the prompt composed of a list of messages. + stop: a list of strings on which the model should stop generating. + If generation stops due to a stop token, the stop token itself + SHOULD BE INCLUDED as part of the output. This is not enforced + across models right now, but it's a good practice to follow since + it makes it much easier to parse the output of the model + downstream and understand why generation stopped. + + Yields: + chunk: ChatResponseGen with model partial generation + """ + response = self._handle_request(messages, stop, streaming=True) + for ai_message_chunk in self._process_stream_response(response): + chunk = ChatResponse(message=ai_message_chunk) + yield chunk + + @llm_completion_callback() + def complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + complete_fn = chat_to_completion_decorator(self.chat) + return complete_fn(prompt, **kwargs) + + @llm_completion_callback() + def stream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseGen: + stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat) + return stream_complete_fn(prompt, **kwargs) + + @llm_chat_callback() + async def achat( + self, + messages: Sequence[ChatMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> ChatResponse: + """Calls the chat implementation of the SambaStudio model. + + Args: + messages: the prompt composed of a list of messages. + stop: a list of strings on which the model should stop generating. + If generation stops due to a stop token, the stop token itself + SHOULD BE INCLUDED as part of the output. This is not enforced + across models right now, but it's a good practice to follow since + it makes it much easier to parse the output of the model + downstream and understand why generation stopped. + + Returns: + ChatResponse with model generation + """ + response_dict = await self._handle_request_async( + messages, stop, streaming=False + ) + message = await self._process_response_async(response_dict) + return ChatResponse(message=message) + + @llm_chat_callback() + async def astream_chat( + self, + messages: Sequence[ChatMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> ChatResponseAsyncGen: + raise NotImplementedError( + "SambaStudio does not currently support async streaming." + ) + + @llm_completion_callback() + async def acomplete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + acomplete_fn = achat_to_completion_decorator(self.achat) + return await acomplete_fn(prompt, **kwargs) + + @llm_completion_callback() + def astream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseAsyncGen: + raise NotImplementedError( + "SambaStudio does not currently support async streaming." + ) diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/pyproject.toml similarity index 79% rename from llama-index-integrations/llms/llama-index-llms-sambanovacloud/pyproject.toml rename to llama-index-integrations/llms/llama-index-llms-sambanovasystems/pyproject.toml index ce648b09869b4..2d7a68d2860bb 100644 --- a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/pyproject.toml @@ -9,10 +9,11 @@ skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" [tool.llamahub] contains_example = false -import_path = "llama_index.llms.sambanovacloud" +import_path = "llama_index.llms.sambanovasystems" [tool.llamahub.class_authors] -SambaNovaCloud = "Pradeep" +SambaNovaCloud = "rodrigo-92" +SambaStudio = "rodrigo-92" [tool.mypy] disallow_untyped_defs = true @@ -21,9 +22,9 @@ ignore_missing_imports = true python_version = "3.8" [tool.poetry] -authors = ["Your Name "] -description = "llama-index llms sambanova cloud integration" -name = "llama-index-llms-sambanovacloud" +authors = ["Rodrigo Maldonado "] +description = "llama-index llms sambanova cloud and sambastudio integration" +name = "llama-index-llms-sambanovasystems" readme = "README.md" version = "0.4.0" diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/tests/BUILD b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/tests/BUILD similarity index 100% rename from llama-index-integrations/llms/llama-index-llms-sambanovacloud/tests/BUILD rename to llama-index-integrations/llms/llama-index-llms-sambanovasystems/tests/BUILD diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/tests/__init__.py b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/tests/__init__.py similarity index 100% rename from llama-index-integrations/llms/llama-index-llms-sambanovacloud/tests/__init__.py rename to llama-index-integrations/llms/llama-index-llms-sambanovasystems/tests/__init__.py diff --git a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/tests/test_llms_sambanova.py b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/tests/test_llms_sambanovasystems.py similarity index 60% rename from llama-index-integrations/llms/llama-index-llms-sambanovacloud/tests/test_llms_sambanova.py rename to llama-index-integrations/llms/llama-index-llms-sambanovasystems/tests/test_llms_sambanovasystems.py index c8df197f80659..0d3da6f7a5c87 100644 --- a/llama-index-integrations/llms/llama-index-llms-sambanovacloud/tests/test_llms_sambanova.py +++ b/llama-index-integrations/llms/llama-index-llms-sambanovasystems/tests/test_llms_sambanovasystems.py @@ -1,19 +1,18 @@ -import time import asyncio +import time +import os +import pytest from llama_index.core.base.llms.types import ( ChatMessage, MessageRole, ) - -import os -import sys - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from llama_index.llms.sambanovacloud import SambaNovaCloud -import pytest +from llama_index.core.llms.llm import LLM +from llama_index.llms.sambanovasystems import SambaNovaCloud, SambaStudio sambanova_api_key = os.environ.get("SAMBANOVA_API_KEY", None) +sambastudio_url = os.environ.get("SAMBASTUDIO_URL", None) +sambastudio_api_key = os.environ.get("SAMBASTUDIO_API_KEY", None) @pytest.mark.asyncio() @@ -45,8 +44,8 @@ def get_execution_time(fn, chat_msgs, is_async=False, number=10): ) -@pytest.mark.skipif(not sambanova_api_key, reason="No openai api key set") -def test_sambanovacloud(): +@pytest.mark.skipif(not sambanova_api_key, reason="No api key set") +def test_calls(sambanova_client: LLM): # chat interaction example user_message = ChatMessage( role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence" @@ -56,32 +55,28 @@ def test_sambanovacloud(): user_message, ] - sambanovacloud_client = SambaNovaCloud() - # sync - print(f"chat response: {sambanovacloud_client.chat(chat_text_msgs)}\n") + print(f"chat response: {sambanova_client.chat(chat_text_msgs)}\n") print( - f"stream chat response: {[x.message.content for x in sambanovacloud_client.stream_chat(chat_text_msgs)]}\n" + f"stream chat response: {[x.message.content for x in sambanova_client.stream_chat(chat_text_msgs)]}\n" ) + print(f"complete response: {sambanova_client.complete(user_message.content)}\n") print( - f"complete response: {sambanovacloud_client.complete(user_message.content)}\n" - ) - print( - f"stream complete response: {[x.text for x in sambanovacloud_client.stream_complete(user_message.content)]}\n" + f"stream complete response: {[x.text for x in sambanova_client.stream_complete(user_message.content)]}\n" ) # async print( - f"async chat response: {asyncio.run(sambanovacloud_client.achat(chat_text_msgs))}\n" + f"async chat response: {asyncio.run(sambanova_client.achat(chat_text_msgs))}\n" ) print( - f"async complete response: {asyncio.run(sambanovacloud_client.acomplete(user_message.content))}\n" + f"async complete response: {asyncio.run(sambanova_client.acomplete(user_message.content))}\n" ) -@pytest.mark.skipif(not sambanova_api_key, reason="No openai api key set") -def test_sambanovacloud_performance(): +@pytest.mark.skipif(not sambanova_api_key, reason="No api key set") +def test_performance(sambanova_client: LLM): # chat interaction example user_message = ChatMessage( role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence" @@ -91,23 +86,19 @@ def test_sambanovacloud_performance(): user_message, ] - sambanovacloud_client = SambaNovaCloud() - # chat - get_execution_time(sambanovacloud_client.chat, chat_text_msgs, number=5) - get_execution_time( - sambanovacloud_client.achat, chat_text_msgs, is_async=True, number=5 - ) + get_execution_time(sambanova_client.chat, chat_text_msgs, number=5) + get_execution_time(sambanova_client.achat, chat_text_msgs, is_async=True, number=5) # complete - get_execution_time(sambanovacloud_client.complete, user_message.content, number=5) + get_execution_time(sambanova_client.complete, user_message.content, number=5) get_execution_time( - sambanovacloud_client.acomplete, user_message.content, is_async=True, number=5 + sambanova_client.acomplete, user_message.content, is_async=True, number=5 ) -@pytest.mark.skipif(not sambanova_api_key, reason="No openai api key set") -def test_hiperparameters(): +@pytest.mark.skipif(not sambanova_api_key, reason="No api key set") +def test_hiperparameters(sambanova_cls: LLM, testing_model: str): user_message = ChatMessage( role=MessageRole.USER, content="Tell me about Naruto Uzumaki in one sentence" ) @@ -116,29 +107,27 @@ def test_hiperparameters(): user_message, ] - model_list = ["llama3-8b", "llama3-70b"] max_tokens_list = [10, 100] temperature_list = [0, 1] top_p_list = [0, 1] top_k_list = [1, 50] stream_options_list = [{"include_usage": False}, {"include_usage": True}] - for model in model_list: - sambanovacloud_client = SambaNovaCloud( - model=model, - max_tokens=max_tokens_list[0], - temperature=temperature_list[0], - top_p=top_p_list[0], - top_k=top_k_list[0], - stream_options=stream_options_list[0], - ) - print( - f"model: {model}, generation: {sambanovacloud_client.chat(chat_text_msgs)}" - ) + sambanovacloud_client = sambanova_cls( + model=testing_model, + max_tokens=max_tokens_list[0], + temperature=temperature_list[0], + top_p=top_p_list[0], + top_k=top_k_list[0], + stream_options=stream_options_list[0], + ) + print( + f"model: {testing_model}, generation: {sambanovacloud_client.chat(chat_text_msgs)}" + ) for max_tokens in max_tokens_list: - sambanovacloud_client = SambaNovaCloud( - model=model_list[0], + sambanovacloud_client = sambanova_cls( + model=testing_model, max_tokens=max_tokens, temperature=temperature_list[0], top_p=top_p_list[0], @@ -150,8 +139,8 @@ def test_hiperparameters(): ) for temperature in temperature_list: - sambanovacloud_client = SambaNovaCloud( - model=model_list[0], + sambanovacloud_client = sambanova_cls( + model=testing_model, max_tokens=max_tokens_list[0], temperature=temperature, top_p=top_p_list[0], @@ -163,8 +152,8 @@ def test_hiperparameters(): ) for top_p in top_p_list: - sambanovacloud_client = SambaNovaCloud( - model=model_list[0], + sambanovacloud_client = sambanova_cls( + model=testing_model, max_tokens=max_tokens_list[0], temperature=temperature_list[0], top_p=top_p, @@ -176,8 +165,8 @@ def test_hiperparameters(): ) for top_k in top_k_list: - sambanovacloud_client = SambaNovaCloud( - model=model_list[0], + sambanovacloud_client = sambanova_cls( + model=testing_model, max_tokens=max_tokens_list[0], temperature=temperature_list[0], top_p=top_p_list[0], @@ -189,8 +178,8 @@ def test_hiperparameters(): ) for stream_options in stream_options_list: - sambanovacloud_client = SambaNovaCloud( - model=model_list[0], + sambanovacloud_client = sambanova_cls( + model=testing_model, max_tokens=max_tokens_list[0], temperature=temperature_list[0], top_p=top_p_list[0], @@ -202,7 +191,24 @@ def test_hiperparameters(): ) +@pytest.mark.skipif(not sambanova_api_key, reason="No api key set") +def test_sambanovacloud(): + testing_model = "llama3-8b" + sambanova_client = SambaNovaCloud() + test_calls(sambanova_client) + test_performance(sambanova_client) + test_hiperparameters(SambaNovaCloud, testing_model) + + +@pytest.mark.skipif(not sambastudio_api_key, reason="No api key set") +def test_sambastudio(): + testing_model = "Meta-Llama-3-70B-Instruct-4096" + sambanova_client = SambaStudio(model=testing_model) + test_calls(sambanova_client) + test_performance(sambanova_client) + test_hiperparameters(SambaStudio, testing_model) + + if __name__ == "__main__": test_sambanovacloud() - test_sambanovacloud_performance() - test_hiperparameters() + test_sambastudio()