-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #933 from ScrapeGraphAI/codebeaver/pre/beta-932
Pre/beta - Unit Tests
- Loading branch information
Showing
1 changed file
with
270 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
import json | ||
import pytest | ||
from langchain.prompts import ( | ||
PromptTemplate, | ||
) | ||
from langchain_community.chat_models import ( | ||
ChatOllama, | ||
) | ||
from langchain_core.runnables import ( | ||
RunnableParallel, | ||
) | ||
from requests.exceptions import ( | ||
Timeout, | ||
) | ||
from scrapegraphai.nodes.generate_answer_node import ( | ||
GenerateAnswerNode, | ||
) | ||
|
||
|
||
class DummyLLM: | ||
|
||
def __call__(self, *args, **kwargs): | ||
return "dummy response" | ||
|
||
|
||
class DummyLogger: | ||
|
||
def info(self, msg): | ||
pass | ||
|
||
def error(self, msg): | ||
pass | ||
|
||
|
||
@pytest.fixture | ||
def dummy_node(): | ||
""" | ||
Fixture for a GenerateAnswerNode instance using DummyLLM. | ||
Uses a valid input keys string ("dummy_input & doc") to avoid parsing errors. | ||
""" | ||
node_config = {"llm_model": DummyLLM(), "verbose": False, "timeout": 1} | ||
node = GenerateAnswerNode("dummy_input & doc", ["output"], node_config=node_config) | ||
node.logger = DummyLogger() | ||
node.get_input_keys = lambda state: ["dummy_input", "doc"] | ||
return node | ||
|
||
|
||
def test_process_missing_content_and_user_prompt(dummy_node): | ||
""" | ||
Test that process() raises a ValueError when either the content or the user prompt is missing. | ||
""" | ||
state_missing_content = {"user_prompt": "What is the answer?"} | ||
with pytest.raises(ValueError) as excinfo1: | ||
dummy_node.process(state_missing_content) | ||
assert "No content found in state" in str(excinfo1.value) | ||
state_missing_prompt = {"content": "Some valid context content"} | ||
with pytest.raises(ValueError) as excinfo2: | ||
dummy_node.process(state_missing_prompt) | ||
assert "No user prompt found in state" in str(excinfo2.value) | ||
|
||
|
||
class DummyLLMWithPipe: | ||
"""DummyLLM that supports the pipe '|' operator. | ||
When used in a chain with a PromptTemplate, the pipe operator returns self, | ||
simulating chain composition.""" | ||
|
||
def __or__(self, other): | ||
return self | ||
|
||
def __call__(self, *args, **kwargs): | ||
return {"content": "script single-chunk answer"} | ||
|
||
|
||
@pytest.fixture | ||
def dummy_node_with_pipe(): | ||
""" | ||
Fixture for a GenerateAnswerNode instance using DummyLLMWithPipe. | ||
Uses a valid input keys string ("dummy_input & doc") to avoid parsing errors. | ||
""" | ||
node_config = {"llm_model": DummyLLMWithPipe(), "verbose": False, "timeout": 480} | ||
node = GenerateAnswerNode("dummy_input & doc", ["output"], node_config=node_config) | ||
node.logger = DummyLogger() | ||
node.get_input_keys = lambda state: ["dummy_input", "doc"] | ||
return node | ||
|
||
|
||
def test_execute_multiple_chunks(dummy_node_with_pipe): | ||
""" | ||
Test the execute() method for a scenario with multiple document chunks. | ||
It simulates parallel processing of chunks and then merges them. | ||
""" | ||
state = { | ||
"dummy_input": "What is the final answer?", | ||
"doc": ["Chunk text 1", "Chunk text 2"], | ||
} | ||
|
||
def fake_invoke_with_timeout(chain, inputs, timeout): | ||
if isinstance(chain, RunnableParallel): | ||
return { | ||
"chunk1": {"content": "answer for chunk 1"}, | ||
"chunk2": {"content": "answer for chunk 2"}, | ||
} | ||
if "context" in inputs and "question" in inputs: | ||
return {"content": "merged final answer"} | ||
return {"content": "single answer"} | ||
|
||
dummy_node_with_pipe.invoke_with_timeout = fake_invoke_with_timeout | ||
output_state = dummy_node_with_pipe.execute(state) | ||
assert output_state["output"] == {"content": "merged final answer"} | ||
|
||
|
||
def test_execute_single_chunk(dummy_node_with_pipe): | ||
""" | ||
Test the execute() method for a single document chunk. | ||
""" | ||
state = {"dummy_input": "What is the answer?", "doc": ["Only one chunk text"]} | ||
|
||
def fake_invoke_with_timeout(chain, inputs, timeout): | ||
if "question" in inputs: | ||
return {"content": "single-chunk answer"} | ||
return {"content": "unexpected result"} | ||
|
||
dummy_node_with_pipe.invoke_with_timeout = fake_invoke_with_timeout | ||
output_state = dummy_node_with_pipe.execute(state) | ||
assert output_state["output"] == {"content": "single-chunk answer"} | ||
|
||
|
||
def test_execute_merge_json_decode_error(dummy_node_with_pipe): | ||
""" | ||
Test that execute() handles a JSONDecodeError in the merge chain properly. | ||
""" | ||
state = { | ||
"dummy_input": "What is the final answer?", | ||
"doc": ["Chunk 1 text", "Chunk 2 text"], | ||
} | ||
|
||
def fake_invoke_with_timeout(chain, inputs, timeout): | ||
if isinstance(chain, RunnableParallel): | ||
return { | ||
"chunk1": {"content": "answer for chunk 1"}, | ||
"chunk2": {"content": "answer for chunk 2"}, | ||
} | ||
if "context" in inputs and "question" in inputs: | ||
raise json.JSONDecodeError("Invalid JSON", "", 0) | ||
return {"content": "unexpected response"} | ||
|
||
dummy_node_with_pipe.invoke_with_timeout = fake_invoke_with_timeout | ||
output_state = dummy_node_with_pipe.execute(state) | ||
assert "error" in output_state["output"] | ||
assert ( | ||
"Invalid JSON response format during merge" in output_state["output"]["error"] | ||
) | ||
|
||
|
||
class DummyChain: | ||
"""A dummy chain for simulating a chain's invoke behavior. | ||
Returns a successful answer in the expected format.""" | ||
|
||
def invoke(self, inputs): | ||
return {"content": "successful answer"} | ||
|
||
|
||
@pytest.fixture | ||
def dummy_node_for_process(): | ||
""" | ||
Fixture for creating a GenerateAnswerNode instance for testing the process() method success case. | ||
""" | ||
node_config = {"llm_model": DummyChain(), "verbose": False, "timeout": 1} | ||
node = GenerateAnswerNode( | ||
"user_prompt & content", ["output"], node_config=node_config | ||
) | ||
node.logger = DummyLogger() | ||
node.get_input_keys = lambda state: ["user_prompt", "content"] | ||
return node | ||
|
||
|
||
def test_process_success(dummy_node_for_process): | ||
""" | ||
Test that process() successfully generates an answer when both user prompt and content are provided. | ||
""" | ||
state = { | ||
"user_prompt": "What is the answer?", | ||
"content": "This is some valid context.", | ||
} | ||
dummy_node_for_process.chain = DummyChain() | ||
dummy_node_for_process.invoke_with_timeout = ( | ||
lambda chain, inputs, timeout: chain.invoke(inputs) | ||
) | ||
new_state = dummy_node_for_process.process(state) | ||
assert new_state["output"] == {"content": "successful answer"} | ||
|
||
|
||
def test_execute_timeout_single_chunk(dummy_node_with_pipe): | ||
""" | ||
Test that execute() properly handles a Timeout exception in the single chunk branch. | ||
""" | ||
state = {"dummy_input": "What is the answer?", "doc": ["Only one chunk text"]} | ||
|
||
def fake_invoke_timeout(chain, inputs, timeout): | ||
raise Timeout("Simulated timeout error") | ||
|
||
dummy_node_with_pipe.invoke_with_timeout = fake_invoke_timeout | ||
output_state = dummy_node_with_pipe.execute(state) | ||
assert "error" in output_state["output"] | ||
assert "Response timeout exceeded" in output_state["output"]["error"] | ||
assert "Simulated timeout error" in output_state["output"]["raw_response"] | ||
|
||
|
||
def test_execute_script_creator_single_chunk(): | ||
""" | ||
Test the execute() method for the scenario when script_creator mode is enabled. | ||
This verifies that the non-markdown prompt templates branch is executed and the expected answer is generated. | ||
""" | ||
node_config = { | ||
"llm_model": DummyLLMWithPipe(), | ||
"verbose": False, | ||
"timeout": 480, | ||
"script_creator": True, | ||
"force": False, | ||
"is_md_scraper": False, | ||
"additional_info": "TEST INFO: ", | ||
} | ||
node = GenerateAnswerNode("dummy_input & doc", ["output"], node_config=node_config) | ||
node.logger = DummyLogger() | ||
node.get_input_keys = lambda state: ["dummy_input", "doc"] | ||
state = { | ||
"dummy_input": "What is the script answer?", | ||
"doc": ["Only one chunk script"], | ||
} | ||
|
||
def fake_invoke_with_timeout(chain, inputs, timeout): | ||
if "question" in inputs: | ||
return {"content": "script single-chunk answer"} | ||
return {"content": "unexpected response"} | ||
|
||
node.invoke_with_timeout = fake_invoke_with_timeout | ||
output_state = node.execute(state) | ||
assert output_state["output"] == {"content": "script single-chunk answer"} | ||
|
||
|
||
class DummyChatOllama(ChatOllama): | ||
"""A dummy ChatOllama class to simulate ChatOllama behavior.""" | ||
|
||
|
||
class DummySchema: | ||
"""A dummy schema class with a model_json_schema method.""" | ||
|
||
def model_json_schema(self): | ||
return "dummy_schema_json" | ||
|
||
|
||
def test_init_chat_ollama_format(): | ||
""" | ||
Test that the __init__ method of GenerateAnswerNode sets the format attribute of a ChatOllama LLM correctly. | ||
""" | ||
dummy_llm = DummyChatOllama() | ||
node_config = {"llm_model": dummy_llm, "verbose": False, "timeout": 1} | ||
node = GenerateAnswerNode("dummy_input", ["output"], node_config=node_config) | ||
assert node.llm_model.format == "json" | ||
dummy_llm_with_schema = DummyChatOllama() | ||
node_config_with_schema = { | ||
"llm_model": dummy_llm_with_schema, | ||
"verbose": False, | ||
"timeout": 1, | ||
"schema": DummySchema(), | ||
} | ||
node2 = GenerateAnswerNode( | ||
"dummy_input", ["output"], node_config=node_config_with_schema | ||
) | ||
assert node2.llm_model.format == "dummy_schema_json" |