diff --git a/sweepai/core/context_pruning.py b/sweepai/core/context_pruning.py index 0d035555df..94b0d48635 100644 --- a/sweepai/core/context_pruning.py +++ b/sweepai/core/context_pruning.py @@ -5,6 +5,7 @@ import subprocess import urllib from dataclasses import dataclass, field +from unittest.mock import MagicMock, patch import networkx as nx import openai @@ -1048,6 +1049,46 @@ def context_dfs( tracking_id="test", ) repo_context_manager = prep_snippets(cloned_repo, query, ticket_progress) + + # Test handle_function_call + function_call = AnthropicFunctionCall( + function_name="store_file", + function_parameters={ + "file_path": "sweepai/config/client.py", + "justification": "This file contains the logic for reading the sweep.yaml config and needs to be modified." + } + ) + llm_state = {} + output = handle_function_call(repo_context_manager, function_call, llm_state) + assert "SUCCESS" in output + assert any(snippet.file_path == "sweepai/config/client.py" for snippet in repo_context_manager.current_top_snippets) + print("handle_function_call test passed") + + # Test get_relevant_context + with patch("sweepai.core.context_pruning.context_dfs") as mock_context_dfs: + mock_rcm = MagicMock() + mock_rcm.current_top_snippets = [ + Snippet(file_path="sweepai/config/client.py", start=0, end=10, content="mock content"), + Snippet(file_path="sweepai/handlers/create_pr.py", start=0, end=10, content="mock content") + ] + mock_context_dfs.return_value = mock_rcm + + test_query = "Test query" + test_rcm = RepoContextManager( + dir_obj=MagicMock(), + current_top_tree="", + snippets=[], + snippet_scores={}, + cloned_repo=MagicMock() + ) + + result_rcm = get_relevant_context(test_query, test_rcm) + + assert len(result_rcm.current_top_snippets) == 2 + assert any(snippet.file_path == "sweepai/config/client.py" for snippet in result_rcm.current_top_snippets) + assert any(snippet.file_path == "sweepai/handlers/create_pr.py" for snippet in result_rcm.current_top_snippets) + print("get_relevant_context test passed") + rcm = get_relevant_context( query, repo_context_manager,