Skip to content

Commit

Permalink
fix more stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
scott-cohere committed Aug 13, 2024
1 parent 6237ad1 commit a124673
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/backend/routers/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ async def delete_file(
"""
user_id = ctx.get_user_id()
_ = validate_conversation(session, conversation_id, user_id)
validate_file(session, file_id, user_id, conversation_id)
validate_file(session, file_id, user_id, conversation_id, ctx)

# Delete the File DB object
get_file_service().delete_conversation_file_by_id(
Expand Down
6 changes: 5 additions & 1 deletion src/backend/services/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def get_files_in_compass(
list[File]: The files that were created
"""
compass = get_compass()
logger = ctx.get_logger()

files = []
for file_id in file_ids:
Expand All @@ -411,6 +412,9 @@ def get_files_in_compass(
parameters={"index": index, "file_id": file_id},
).result["doc"]["content"]
except Exception as e:
logger.error(
event=f"[Compass File Service] Error fetching file {file_id} on index {index} from Compass: {e}"
)
raise HTTPException(
status_code=404, detail=f"File with ID: {file_id} not found."
)
Expand Down Expand Up @@ -593,7 +597,7 @@ async def insert_files_in_compass(

# Misc
def validate_file(
session: DBSessionDep, file_id: str, user_id: str, ctx: Context, index: str = None
session: DBSessionDep, file_id: str, user_id: str, index: str, ctx: Context
) -> File:
"""Validates if a file exists and belongs to the user
Expand Down
16 changes: 0 additions & 16 deletions src/backend/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,19 +225,3 @@ def mock_available_model_deployments(request):

with patch.dict(AVAILABLE_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock:
yield mock


@pytest.fixture
def mock_compass_settings():
with patch("backend.services.file.Settings") as MockSettings:
mock_settings = MockSettings.return_value
mock_settings.feature_flags.use_compass_file_storage = os.getenv(
"ENABLE_COMPASS_FILE_STORAGE", "False"
).lower() in ("true", "1")
mock_settings.tools.compass.api_url = os.getenv("COHERE_COMPASS_API_URL")
mock_settings.tools.compass.api_parser_url = os.getenv(
"COHERE_COMPASS_API_PARSER_URL"
)
mock_settings.tools.compass.username = os.getenv("COHERE_COMPASS_USERNAME")
mock_settings.tools.compass.password = os.getenv("COHERE_COMPASS_PASSWORD")
yield mock_settings
16 changes: 16 additions & 0 deletions src/backend/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,19 @@ def mock_available_model_deployments(request):

with patch.dict(AVAILABLE_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock:
yield mock


@pytest.fixture
def mock_compass_settings():
with patch("backend.services.file.Settings") as MockSettings:
mock_settings = MockSettings.return_value
mock_settings.feature_flags.use_compass_file_storage = os.getenv(
"ENABLE_COMPASS_FILE_STORAGE", "False"
).lower() in ("true", "1")
mock_settings.tools.compass.api_url = os.getenv("COHERE_COMPASS_API_URL")
mock_settings.tools.compass.api_parser_url = os.getenv(
"COHERE_COMPASS_API_PARSER_URL"
)
mock_settings.tools.compass.username = os.getenv("COHERE_COMPASS_USERNAME")
mock_settings.tools.compass.password = os.getenv("COHERE_COMPASS_PASSWORD")
yield mock_settings
2 changes: 1 addition & 1 deletion src/backend/tests/unit/routers/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ def test_streaming_chat_with_files(
"files",
(
"Mariana_Trench.pdf",
open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb"),
open("src/backend/tests/unit/test_data/Mariana_Trench.pdf", "rb"),
),
)
]
Expand Down
9 changes: 2 additions & 7 deletions src/backend/tests/unit/routers/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,11 +476,10 @@ def test_list_files(
"files",
(
"Mariana_Trench.pdf",
open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb"),
open("src/backend/tests/unit/test_data/Mariana_Trench.pdf", "rb"),
),
)
]

response = session_client.post(
"/v1/conversations/batch_upload_file",
headers={"User-Id": conversation.user_id},
Expand Down Expand Up @@ -531,7 +530,6 @@ def test_upload_file_existing_conversation(
session_client: TestClient, session: Session, user: User, mock_compass_settings
) -> None:
file_path = "src/backend/tests/unit/test_data/Mariana_Trench.pdf"
saved_file_path = "src/backend/data/Mariana_Trench.pdf"
conversation = get_factory("Conversation", session).create(user_id=user.id)
file_doc = {"file": open(file_path, "rb")}

Expand All @@ -556,7 +554,6 @@ def test_upload_file_nonexistent_conversation_creates_new_conversation(
session_client: TestClient, session: Session, user: User, mock_compass_settings
) -> None:
file_path = "src/backend/tests/unit/test_data/Mariana_Trench.pdf"
saved_file_path = "src/backend/data/Mariana_Trench.pdf"
file_doc = {"file": open(file_path, "rb")}

response = session_client.post(
Expand Down Expand Up @@ -606,8 +603,6 @@ def test_batch_upload_file_existing_conversation(
file_paths = {
"Mariana_Trench.pdf": "src/backend/tests/unit/test_data/Mariana_Trench.pdf",
"Cardistry.pdf": "src/backend/tests/unit/test_data/Cardistry.pdf",
"Tapas.pdf": "src/backend/tests/unit/test_data/Tapas.pdf",
"Mount_Everest.pdf": "src/backend/tests/unit/test_data/Mount_Everest.pdf",
}
files = [
("files", (file_name, open(file_path, "rb")))
Expand Down Expand Up @@ -800,7 +795,7 @@ def test_delete_file(
"files",
(
"Mariana_Trench.pdf",
open("src/backend/tests/test_data/Mariana_Trench.pdf", "rb"),
open("src/backend/tests/unit/test_data/Mariana_Trench.pdf", "rb"),
),
)
]
Expand Down

0 comments on commit a124673

Please sign in to comment.