Skip to content

Commit

Permalink
feat(backend): remove dependencies on cohere api key
Browse files (browse the repository at this point in the history)
  • Loading branch information
ezawadski committed Jan 24, 2025
1 parent f629cbb commit 0cec72a
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 25 deletions.
7 changes: 6 additions & 1 deletion src/backend/chat/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,12 @@ async def rerank_and_chunk(
return list(reranked_results.values())


def chunk(content, compact_mode=False, soft_word_cut_off=100, hard_word_cut_off=300):
def chunk(
content: str,
compact_mode: bool = False,
soft_word_cut_off: int = 100,
hard_word_cut_off: int = 300,
) -> list[str]:
if compact_mode:
content = content.replace("\n", " ")

Expand Down
30 changes: 15 additions & 15 deletions src/backend/tests/unit/tools/test_collate.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import os

import pytest

from backend.chat import collate
from backend.model_deployments import CohereDeployment

is_cohere_env_set = (
os.environ.get("COHERE_API_KEY") is not None
and os.environ.get("COHERE_API_KEY") != ""
)
from backend.schemas.context import Context


@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
@pytest.mark.asyncio
async def test_rerank() -> None:
model = CohereDeployment(model_config={})
outputs = [
Expand Down Expand Up @@ -73,19 +67,21 @@ async def test_rerank() -> None:
},
]

assert await collate.rerank_and_chunk(tool_results, model) == expected_output
assert await collate.rerank_and_chunk(tool_results, model, Context()) == expected_output


def test_chunk_normal_mode() -> None:
content = "This is a test. We are testing the chunk function."
expected_output = ["This is a test.", "We are testing the chunk function."]
collate.chunk(content, False, 4, 10) == expected_output
output = collate.chunk(content, False, 3, 10)
assert output == expected_output


def test_chunk_compact_mode() -> None:
content = "This is a test.\nWe are testing the chunk function."
expected_output = ["This is a test.", "We are testing the chunk function."]
collate.chunk(content, True, 4, 10) == expected_output
output = collate.chunk(content, True, 3, 10)
assert output == expected_output


def test_chunk_hard_cut_off() -> None:
Expand All @@ -94,19 +90,23 @@ def test_chunk_hard_cut_off() -> None:
"This is a test. We are testing the chunk function.",
"This sentence will exceed the hard cut off.",
]
collate.chunk(content, False, 4, 10) == expected_output
output = collate.chunk(content, False, 11, 10)
assert output == expected_output


def test_chunk_soft_cut_off() -> None:
content = "This is a test. We are testing the chunk function. This sentence will exceed the soft cut off."
expected_output = [
"This is a test.",
"We are testing the chunk function. This sentence will exceed the soft cut off.",
"We are testing the chunk function.",
"This sentence will exceed the soft cut off.",
]
collate.chunk(content, False, 4, 10) == expected_output
output = collate.chunk(content, False, 3, 10)
assert output == expected_output


def test_chunk_empty_content() -> None:
content = ""
expected_output = []
collate.chunk(content, False, 4, 10) == expected_output
output = collate.chunk(content, False, 3, 10)
assert output == expected_output
9 changes: 0 additions & 9 deletions src/backend/tests/unit/tools/test_lang_chain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -8,11 +7,6 @@
from backend.tools import LangChainVectorDBRetriever, LangChainWikiRetriever
from backend.tools.base import ToolError, ToolErrorCode

is_cohere_env_set = (
os.environ.get("COHERE_API_KEY") is not None
and os.environ.get("COHERE_API_KEY") != ""
)


@pytest.mark.asyncio
async def test_wiki_retriever() -> None:
Expand Down Expand Up @@ -62,7 +56,6 @@ async def test_wiki_retriever() -> None:
assert result == expected_docs


@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
@pytest.mark.asyncio
async def test_wiki_retriever_no_docs() -> None:
ctx = Context()
Expand All @@ -83,7 +76,6 @@ async def test_wiki_retriever_no_docs() -> None:



@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
@pytest.mark.asyncio
async def test_vector_db_retriever() -> None:
ctx = Context()
Expand Down Expand Up @@ -145,7 +137,6 @@ async def test_vector_db_retriever() -> None:
assert result == expected_docs


@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
@pytest.mark.asyncio
async def test_vector_db_retriever_no_docs() -> None:
ctx = Context()
Expand Down

0 comments on commit 0cec72a

Please sign in to comment.