Skip to content

Commit

Permalink
Merge pull request #912 from ScrapeGraphAI/codebeaver/pre/beta-904
Browse files Browse the repository at this point in the history
codebeaver/pre/beta-904 - Unit Tests
  • Loading branch information
VinciGit00 authored Feb 6, 2025
2 parents c002bf4 + 80dd766 commit 948164f
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 2 deletions.
34 changes: 34 additions & 0 deletions tests/test_depth_search_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from unittest.mock import patch, MagicMock
from scrapegraphai.graphs.depth_search_graph import DepthSearchGraph
from scrapegraphai.graphs.abstract_graph import AbstractGraph
import pytest


class TestDepthSearchGraph:
"""Test suite for DepthSearchGraph class"""

@pytest.mark.parametrize(
"source, expected_input_key",
[
("https://example.com", "url"),
("/path/to/local/directory", "local_dir"),
],
)
def test_depth_search_graph_initialization(self, source, expected_input_key):
"""
Test that DepthSearchGraph initializes correctly with different source types.
This test verifies that the input_key is set to 'url' for web sources and
'local_dir' for local directory sources.
"""
prompt = "Test prompt"
config = {"llm": {"model": "mock_model"}}

# Mock both BaseGraph and _create_llm method
with patch("scrapegraphai.graphs.depth_search_graph.BaseGraph"), \
patch.object(AbstractGraph, '_create_llm', return_value=MagicMock()):
graph = DepthSearchGraph(prompt, source, config)

assert graph.prompt == prompt
assert graph.source == source
assert graph.config == config
assert graph.input_key == expected_input_key
58 changes: 57 additions & 1 deletion tests/test_json_scraper_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from pydantic import BaseModel
from pydantic import BaseModel, Field
from scrapegraphai.graphs.json_scraper_graph import JSONScraperGraph
from unittest.mock import Mock, patch

Expand Down Expand Up @@ -133,4 +133,60 @@ def test_json_scraper_graph_no_answer_found(self, mock_create_llm, mock_generate
mock_execute.assert_called_once_with({"user_prompt": "Query that produces no answer", "json": "path/to/empty/file.json"})
mock_fetch_node.assert_called_once()
mock_generate_answer_node.assert_called_once()
mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0})

@pytest.fixture
def mock_llm_model(self):
return Mock()

@pytest.fixture
def mock_embedder_model(self):
return Mock()

@patch('scrapegraphai.graphs.json_scraper_graph.FetchNode')
@patch('scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode')
@patch.object(JSONScraperGraph, '_create_llm')
def test_json_scraper_graph_with_custom_schema(self, mock_create_llm, mock_generate_answer_node, mock_fetch_node, mock_llm_model, mock_embedder_model):
"""
Test JSONScraperGraph with a custom schema.
This test checks if the graph correctly handles a custom schema input
and passes it to the GenerateAnswerNode.
"""
# Define a custom schema
class CustomSchema(BaseModel):
name: str = Field(..., description="Name of the attraction")
description: str = Field(..., description="Description of the attraction")

# Mock the _create_llm method to return a mock LLM model
mock_create_llm.return_value = mock_llm_model

# Mock the execute method of BaseGraph
with patch('scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute') as mock_execute:
mock_execute.return_value = ({"answer": "Mocked answer with custom schema"}, {})

# Create a JSONScraperGraph instance with a custom schema
graph = JSONScraperGraph(
prompt="List attractions in Chioggia",
source="path/to/chioggia.json",
config={"llm": {"model": "test-model", "temperature": 0}},
schema=CustomSchema
)

# Set mocked embedder model
graph.embedder_model = mock_embedder_model

# Run the graph
result = graph.run()

# Assertions
assert result == "Mocked answer with custom schema"
assert graph.input_key == "json"
mock_execute.assert_called_once_with({"user_prompt": "List attractions in Chioggia", "json": "path/to/chioggia.json"})
mock_fetch_node.assert_called_once()
mock_generate_answer_node.assert_called_once()

# Check if the custom schema was passed to GenerateAnswerNode
generate_answer_node_call = mock_generate_answer_node.call_args[1]
assert generate_answer_node_call['node_config']['schema'] == CustomSchema

mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0})
27 changes: 26 additions & 1 deletion tests/test_search_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,29 @@ def test_max_results_config(self, mock_create_llm, mock_base_graph, mock_merge_a
# Assert
mock_search_internet.assert_called_once()
call_args = mock_search_internet.call_args
assert call_args.kwargs['node_config']['max_results'] == max_results
assert call_args.kwargs['node_config']['max_results'] == max_results

@patch('scrapegraphai.graphs.search_graph.SearchInternetNode')
@patch('scrapegraphai.graphs.search_graph.GraphIteratorNode')
@patch('scrapegraphai.graphs.search_graph.MergeAnswersNode')
@patch('scrapegraphai.graphs.search_graph.BaseGraph')
@patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm')
def test_custom_search_engine_config(self, mock_create_llm, mock_base_graph, mock_merge_answers, mock_graph_iterator, mock_search_internet):
"""
Test that the custom search_engine parameter from the config is correctly passed to the SearchInternetNode.
"""
# Arrange
prompt = "Test prompt"
custom_search_engine = "custom_engine"
config = {
"llm": {"model": "test-model"},
"search_engine": custom_search_engine
}

# Act
search_graph = SearchGraph(prompt, config)

# Assert
mock_search_internet.assert_called_once()
call_args = mock_search_internet.call_args
assert call_args.kwargs['node_config']['search_engine'] == custom_search_engine

0 comments on commit 948164f

Please sign in to comment.