Skip to content

Commit

Permalink
Removed the `exclude_text` option from the Exa searcher
Browse files Browse the repository at this point in the history
  • Loading branch information
CodexVeritas committed Dec 11, 2024
1 parent 9a2de50 commit 1a8040a
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 68 deletions.
8 changes: 2 additions & 6 deletions code_tests/low_cost_or_live_api_tests/test_exa_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,12 @@ async def test_filtered_invoke() -> None:
model = ExaSearcher(
num_results=num_results, include_highlights=False, include_text=True
)
exclude_domains = ["alliance.health"]
search = SearchInput(
web_search_query="coronavirus",
highlight_query=None,
include_domains=[],
exclude_domains=exclude_domains,
exclude_domains=["alliance.health"],
include_text="pregnancy",
exclude_text="symptoms",
start_published_date=datetime(2022, 11, 1),
end_published_date=datetime(2022, 11, 30),
)
Expand All @@ -117,13 +115,11 @@ async def test_filtered_invoke() -> None:
assert source.published_date <= search.end_published_date
assert source.published_date >= search.start_published_date
assert search.include_text is not None
assert search.exclude_text is not None
assert search.include_text in source.text
assert search.exclude_text not in source.text
assert source.url is not None
assert all(
exclude_domain not in source.url
for exclude_domain in exclude_domains
for exclude_domain in search.exclude_domains
)
assert len(source.highlights) == 0
assert len(source.highlight_scores) == 0
Expand Down
40 changes: 13 additions & 27 deletions code_tests/low_cost_or_live_api_tests/test_metaculus_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,20 @@ def test_get_questions_from_tournament() -> None:
questions = MetaculusApi.get_all_open_questions_from_tournament(
ForecastingTestManager.TOURN_WITH_OPENNESS_AND_TYPE_VARIATIONS
)
assert any(isinstance(question, BinaryQuestion) for question in questions)
assert any(isinstance(question, NumericQuestion) for question in questions)
assert any(isinstance(question, DateQuestion) for question in questions)
assert any(
score = 0
if any(isinstance(question, BinaryQuestion) for question in questions):
score += 1
if any(isinstance(question, NumericQuestion) for question in questions):
score += 1
if any(isinstance(question, DateQuestion) for question in questions):
score += 1
if any(
isinstance(question, MultipleChoiceQuestion) for question in questions
)
):
score += 1
assert (
score > 1
), "There needs to be multiple question types in the tournament"

for question in questions:
assert question.state == QuestionState.OPEN
Expand Down Expand Up @@ -188,28 +196,6 @@ def test_get_benchmark_questions(num_questions_to_get: int) -> None:
), "Questions retrieved with same random seed should return same IDs"


def test_get_questions_from_current_quartely_cup() -> None:
    """Check the binary questions fetched for the current quarterly cup.

    If the quarterly cup is not active, no questions are expected. Otherwise
    at least one question must be returned and a known question text must be
    among them. In either case every returned question must be open and
    binary.
    """
    expected_question_text = "Will BirdCast report 1 billion birds flying over the United States at any point before January 1, 2025?"
    questions = (
        MetaculusApi._get_open_binary_questions_from_current_quarterly_cup()
    )

    cup_is_inactive = ForecastingTestManager.quarterly_cup_is_not_active()
    if cup_is_inactive:
        assert len(questions) == 0
    else:
        assert len(questions) > 0
        fetched_texts = [question.question_text for question in questions]
        assert expected_question_text in fetched_texts

    # These hold trivially for an empty result, so they are checked
    # regardless of whether the cup is active.
    for question in questions:
        assert (
            question.state == QuestionState.OPEN
        ), "Expected all questions to be open"
    for question in questions:
        assert isinstance(
            question, BinaryQuestion
        ), "Expected all questions to be binary"


def assert_basic_question_attributes_not_none(
question: MetaculusQuestion, question_id: int
) -> None:
Expand Down
10 changes: 0 additions & 10 deletions code_tests/low_cost_or_live_api_tests/test_smart_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,6 @@ async def test_ask_question_empty_prompt() -> None:
await searcher.invoke("")


async def test_screenshot_question() -> None:
    """Ask SmartSearcher about a dip in a FRED graph and expect "2020" in the answer.

    The result and the cost tracked by MonetaryCostManager are logged.
    """
    question = "When was the most noticeable recent dip in the graph from https://fred.stlouisfed.org/series/GDPC1? Say 0 if you do not know. Please search specifically for the site itself."
    with MonetaryCostManager() as cost_manager:
        result = await SmartSearcher(num_sites_to_deep_dive=2).invoke(question)
        logger.info(f"Result: {result}")
        logger.info(f"Cost: {cost_manager.current_usage}")
        assert "2020" in result


@pytest.mark.skip("Run this when needed as it's purely a qualitative test")
async def test_screenshot_question_2() -> None:
with MonetaryCostManager() as cost_manager:
Expand Down
6 changes: 0 additions & 6 deletions forecasting_tools/ai_models/exa_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ class SearchInput(BaseModel, Jsonable):
include_text: str | None = Field(
description="A 1-5 word phrase that must be exactly present in the text of the search results"
)
exclude_text: str | None = Field(
description="A 1-5 word phrase that must not be present in the text of the search results"
)
start_published_date: datetime | None = Field(
description="The earliest publication date for search results"
)
Expand Down Expand Up @@ -212,8 +209,6 @@ def _prepare_request_data(
payload["endPublishedDate"] = search.end_published_date.isoformat()
if search.include_text:
payload["includeText"] = [search.include_text]
if search.exclude_text:
payload["excludeText"] = [search.exclude_text]

return url, headers, payload

Expand All @@ -225,7 +220,6 @@ def __get_default_search_strategy(cls, search_query: str) -> SearchInput:
include_domains=[],
exclude_domains=[],
include_text=None,
exclude_text=None,
start_published_date=None,
end_published_date=None,
)
Expand Down
22 changes: 3 additions & 19 deletions forecasting_tools/forecasting/helpers/metaculus_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,24 +246,6 @@ def __get_general_open_binary_questions_resolving_in_3_months(
)
return checked_questions

@classmethod
def _get_open_binary_questions_from_current_quarterly_cup(
    cls,
) -> list[BinaryQuestion]:
    """Return the binary questions among the current quarterly cup's open questions.

    Fetches the questions via ``cls.get_all_open_questions_from_tournament``
    for the tournament identified by ``cls.CURRENT_QUARTERLY_CUP_ID`` and
    keeps only the ``BinaryQuestion`` instances.

    Returns:
        The fetched questions that are instances of ``BinaryQuestion``
        (possibly an empty list).
    """
    questions = cls.get_all_open_questions_from_tournament(
        cls.CURRENT_QUARTERLY_CUP_ID,
    )
    # The isinstance filter below already guarantees the element type, so a
    # follow-up assertion re-checking it would be tautological.
    binary_questions = [
        question
        for question in questions
        if isinstance(question, BinaryQuestion)
    ]
    return binary_questions  # type: ignore

@classmethod
def __get_questions_from_api(
cls, params: dict[str, Any], use_old_api: bool = False
Expand All @@ -284,7 +266,9 @@ def __get_questions_from_api(
supported_posts = [
q
for q in results
if "notebook" not in q and "group_of_questions" not in q
if "notebook" not in q
and "group_of_questions" not in q
and "conditional" not in q
]
removed_posts = [
post for post in results if post not in supported_posts
Expand Down

0 comments on commit 1a8040a

Please sign in to comment.