Refactoring
jakep-allenai committed Feb 27, 2025
1 parent 9e019f1 commit ae7efd3
Showing 2 changed files with 96 additions and 108 deletions.
143 changes: 43 additions & 100 deletions olmocr/bench/benchmark.py
@@ -13,91 +13,31 @@

import argparse
import glob
import itertools
import json
import os
import sys
from typing import Tuple, List, Dict

from fuzzysearch import find_near_matches
from rapidfuzz import fuzz
from .tests import BasePDFTest, load_tests

from .datatypes import BasePDFTest, TextPresenceTest, TextOrderTest, TestType, load_tests

def run_rule(rule: BasePDFTest, md_content: str) -> Tuple[bool, str]:
"""
Run the given rule on the content of the provided .md file.
Returns a tuple (passed, explanation) where 'passed' is True if the rule passes,
and 'explanation' is a short message explaining the failure when the rule does not pass.
"""
rule_type = rule.type

if rule_type in (TestType.PRESENT.value, TestType.ABSENT.value):
# This is a TextPresenceTest
if not isinstance(rule, TextPresenceTest):
return (False, f"Rule type mismatch: expected TextPresenceTest but got {type(rule).__name__}")

reference_query = rule.text
threshold = rule.threshold
best_ratio = fuzz.partial_ratio(reference_query, md_content) / 100.0

if rule_type == TestType.PRESENT.value:
if best_ratio >= threshold:
return (True, "")
else:
return (False, f"Expected '{reference_query[:40]}...' with threshold {threshold} but best match ratio was {best_ratio:.3f}")
else: # absent
if best_ratio < threshold:
return (True, "")
else:
return (False, f"Expected absence of '{reference_query[:40]}...' with threshold {threshold} but best match ratio was {best_ratio:.3f}")

elif rule_type == TestType.ORDER.value:
# This is a TextOrderTest
if not isinstance(rule, TextOrderTest):
return (False, f"Rule type mismatch: expected TextOrderTest but got {type(rule).__name__}")

before = rule.before
after = rule.after
threshold = rule.threshold
max_l_dist = round((1.0 - threshold) * len(before))

before_matches = find_near_matches(before, md_content, max_l_dist=max_l_dist)
after_matches = find_near_matches(after, md_content, max_l_dist=max_l_dist)

if not before_matches:
return (False, f"'before' search text '{before[:40]}...' not found with max_l_dist {max_l_dist}")
if not after_matches:
return (False, f"'after' search text '{after[:40]}...' not found with max_l_dist {max_l_dist}")

for before_match, after_match in itertools.product(before_matches, after_matches):
if before_match.start < after_match.start:
return (True, "")

return (False, f"Could not find a location where '{before[:40]}...' appears before '{after[:40]}...'.")

else:
raise NotImplementedError(f"Rule type '{rule_type}' is not implemented.")


def evaluate_candidate(candidate_folder: str, all_rules: List[BasePDFTest], pdf_basenames: List[str]) -> Tuple[float, int, List[str], List[str], Dict[str, List[float]]]:
def evaluate_candidate(candidate_folder: str, all_tests: List[BasePDFTest], pdf_basenames: List[str]) -> Tuple[float, int, List[str], List[str], Dict[str, List[float]]]:
"""
For the candidate folder (pipeline tool output), validate that it contains at least one .md file
(i.e. repeated generations like _1.md, _2.md, etc.) for every PDF in the pdf folder.
Then, run each rule against all corresponding .md files and average the results.
Returns a tuple:
(overall_score, total_rules, candidate_errors, rule_failures, rule_type_breakdown)
(overall_score, total_tests, candidate_errors, test_failures, test_type_breakdown)
- overall_score: Average fraction of rules passed (averaged over repeats and rules).
- total_rules: Total number of rules evaluated.
- overall_score: Average fraction of tests passed (averaged over repeats and tests).
- total_tests: Total number of tests evaluated.
- candidate_errors: List of candidate errors (e.g. missing files).
- rule_failures: List of failure messages for rules not passing on all repeats.
- rule_type_breakdown: Dictionary mapping rule type to list of average pass ratios for rules of that type.
- test_failures: List of failure messages for tests not passing on all repeats.
- test_type_breakdown: Dictionary mapping test type to list of average pass ratios for tests of that type.
"""
candidate_errors = []
rule_failures = []
rule_type_breakdown = {} # key: rule type, value: list of average pass ratios
test_failures = []
test_type_breakdown = {} # key: test type, value: list of average pass ratios
candidate_name = os.path.basename(candidate_folder)

# Map each PDF to its corresponding MD repeats (e.g., doc1_1.md, doc1_2.md, etc.)
@@ -112,16 +52,16 @@ def evaluate_candidate(candidate_folder: str, all_rules: List[BasePDFTest], pdf_
pdf_to_md_files[pdf_name] = md_files

if candidate_errors:
return (0.0, len(all_rules), candidate_errors, rule_failures, rule_type_breakdown)
return (0.0, len(all_tests), candidate_errors, test_failures, test_type_breakdown)

total_rule_score = 0.0
total_test_score = 0.0

# Evaluate each rule. Each rule references a PDF (e.g., "doc1.pdf") so we get all its MD repeats.
for rule in all_rules:
rule_type = rule.type
if rule_type not in rule_type_breakdown:
rule_type_breakdown[rule_type] = []
pdf_name = rule.pdf
# Evaluate each test. Each test references a PDF (e.g., "doc1.pdf") so we get all its MD repeats.
for test in all_tests:
test_type = test.type
if test_type not in test_type_breakdown:
test_type_breakdown[test_type] = []
pdf_name = test.pdf
md_base = os.path.splitext(pdf_name)[0]
md_files = pdf_to_md_files.get(pdf_name, [])
if not md_files:
@@ -139,26 +79,27 @@ def evaluate_candidate(candidate_folder: str, all_rules: List[BasePDFTest], pdf_
continue

try:
passed, explanation = run_rule(rule, md_content)
# Use the test's run method to evaluate the content
passed, explanation = test.run(md_content)
if passed:
repeat_passes += 1
else:
explanations.append(explanation)
except Exception as e:
candidate_errors.append(f"Error running rule {rule.id} on {md_path}: {e}")
candidate_errors.append(f"Error running test {test.id} on {md_path}: {e}")
explanations.append(str(e))

rule_avg = repeat_passes / num_repeats if num_repeats > 0 else 0.0
total_rule_score += rule_avg
if rule_avg < 1.0:
rule_failures.append(
f"Rule {rule.id} on {md_base} average pass ratio: {rule_avg:.3f} ({repeat_passes}/{num_repeats} repeats passed). "
test_avg = repeat_passes / num_repeats if num_repeats > 0 else 0.0
total_test_score += test_avg
if test_avg < 1.0:
test_failures.append(
f"Test {test.id} on {md_base} average pass ratio: {test_avg:.3f} ({repeat_passes}/{num_repeats} repeats passed). "
f"Example explanation: {explanations[0] if explanations else 'No explanation'}"
)
rule_type_breakdown[rule_type].append(rule_avg)
test_type_breakdown[test_type].append(test_avg)

overall_score = total_rule_score / len(all_rules) if all_rules else 0.0
return (overall_score, len(all_rules), candidate_errors, rule_failures, rule_type_breakdown)
overall_score = total_test_score / len(all_tests) if all_tests else 0.0
return (overall_score, len(all_tests), candidate_errors, test_failures, test_type_breakdown)
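
As a side note, the repeat-file convention described in the docstring above (one or more generations per PDF, named doc1_1.md, doc1_2.md, ...) can be sketched as follows; the collection code itself is collapsed in the diff, so the folder layout and glob pattern here are assumptions for illustration only.

import glob
import os

# Hypothetical layout: <candidate_folder>/<pdf_basename>_<repeat>.md
candidate_folder = "output/mytool"          # assumed candidate folder name
pdf_name = "doc1.pdf"
md_base = os.path.splitext(pdf_name)[0]     # "doc1"

# Collect all repeated generations for this PDF (doc1_1.md, doc1_2.md, ...).
md_files = sorted(glob.glob(os.path.join(candidate_folder, f"{md_base}_*.md")))
print(md_files)  # e.g. ['output/mytool/doc1_1.md', 'output/mytool/doc1_2.md']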


def main():
@@ -200,7 +141,7 @@ def main():
all_tests.extend(tests)

if not all_tests:
print("No valid rules found. Exiting.", file=sys.stderr)
print("No valid tests found. Exiting.", file=sys.stderr)
sys.exit(1)

# Identify candidate pipeline folders (subdirectories of input_folder excluding /pdfs)
@@ -216,37 +157,39 @@

# Evaluate each candidate
summary = []
print("\nRunning rules for each candidate:")
print("\nRunning tests for each candidate:")
for candidate in candidate_folders:
candidate_name = os.path.basename(candidate)
overall_score, total_rules, candidate_errors, rule_failures, rule_type_breakdown = evaluate_candidate(candidate, all_tests, pdf_basenames)
summary.append((candidate_name, overall_score, total_rules, candidate_errors, rule_failures, rule_type_breakdown))
overall_score, total_tests, candidate_errors, test_failures, test_type_breakdown = evaluate_candidate(
candidate, all_tests, pdf_basenames
)
summary.append((candidate_name, overall_score, total_tests, candidate_errors, test_failures, test_type_breakdown))
print(f"\nCandidate: {candidate_name}")
if candidate_errors:
for err in candidate_errors:
print(f" [ERROR] {err}")
else:
if rule_failures:
for fail in rule_failures:
if test_failures:
for fail in test_failures:
print(f" [FAIL] {fail}")
print(f" Average Score: {overall_score * 100:.1f}% over {total_rules} rules.")
print(f" Average Score: {overall_score * 100:.1f}% over {total_tests} tests.")

# Print final summary with breakdown by rule type
# Print final summary with breakdown by test type
print("\n" + "=" * 50)
print("Final Summary:")
for candidate_name, overall_score, total_rules, candidate_errors, _, rule_type_breakdown in summary:
for candidate_name, overall_score, total_tests, candidate_errors, _, test_type_breakdown in summary:
if candidate_errors:
status = "FAILED (errors)"
else:
status = f"{overall_score * 100:0.1f}%"
print(f"{candidate_name:20s} : Average Score: {overall_score * 100:0.1f}% over {total_rules:3d} rules - {status}")
print(" Breakdown by rule type:")
for rtype, scores in rule_type_breakdown.items():
print(f"{candidate_name:20s} : Average Score: {overall_score * 100:0.1f}% over {total_tests:3d} tests - {status}")
print(" Breakdown by test type:")
for ttype, scores in test_type_breakdown.items():
if scores:
avg = sum(scores) / len(scores) * 100
else:
avg = 0.0
print(f" {rtype:8s}: {avg:0.1f}% average pass rate over {len(scores)} rules")
print(f" {ttype:8s}: {avg:0.1f}% average pass rate over {len(scores)} tests")
print("=" * 50)


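Taken together, the benchmark.py changes above replace the type-switching run_rule() helper with a polymorphic test.run() call. A minimal sketch of the new calling pattern follows; the file paths are illustrative, not part of the commit.

from olmocr.bench.tests import load_tests

tests = load_tests("tests.jsonl")                  # illustrative path
with open("output/mytool/doc1_1.md") as f:         # illustrative candidate output
    md_content = f.read()

for test in tests:
    # Each test subclass now carries its own run() method (see tests.py below).
    passed, explanation = test.run(md_content)
    if not passed:
        print(f"[FAIL] {test.id}: {explanation}")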
61 changes: 53 additions & 8 deletions olmocr/bench/datatypes.py → olmocr/bench/tests.py
@@ -1,9 +1,11 @@
from dataclasses import dataclass, field
from typing import List, Optional, Union, ClassVar, Type
from dataclasses import dataclass
from typing import Tuple
import json
import re
from enum import Enum

from fuzzysearch import find_near_matches
from rapidfuzz import fuzz


class TestType(str, Enum):
PRESENT = "present"
@@ -38,6 +40,14 @@ def __post_init__(self):
# Check that type is valid
if self.type not in [t.value for t in TestType]:
raise ValidationError(f"Invalid test type: {self.type}")

def run(self, md_content: str) -> Tuple[bool, str]:
"""
Run the test on the content of the provided .md file.
Returns a tuple (passed, explanation) where 'passed' is True if the test passes,
and 'explanation' is a short message explaining the failure when the test does not pass.
"""
raise NotImplementedError("Subclasses must implement run method")


@dataclass
@@ -54,6 +64,22 @@ def __post_init__(self):

if not self.text.strip():
raise ValidationError("Text field cannot be empty")

def run(self, md_content: str) -> Tuple[bool, str]:
reference_query = self.text
threshold = self.threshold
best_ratio = fuzz.partial_ratio(reference_query, md_content) / 100.0

if self.type == TestType.PRESENT.value:
if best_ratio >= threshold:
return (True, "")
else:
return (False, f"Expected '{reference_query[:40]}...' with threshold {threshold} but best match ratio was {best_ratio:.3f}")
else: # absent
if best_ratio < threshold:
return (True, "")
else:
return (False, f"Expected absence of '{reference_query[:40]}...' with threshold {threshold} but best match ratio was {best_ratio:.3f}")


@dataclass
@@ -74,7 +100,27 @@ def __post_init__(self):

if not self.after.strip():
raise ValidationError("After field cannot be empty")


def run(self, md_content: str) -> Tuple[bool, str]:
before = self.before
after = self.after
threshold = self.threshold
max_l_dist = round((1.0 - threshold) * len(before))

before_matches = find_near_matches(before, md_content, max_l_dist=max_l_dist)
after_matches = find_near_matches(after, md_content, max_l_dist=max_l_dist)

if not before_matches:
return (False, f"'before' search text '{before[:40]}...' not found with max_l_dist {max_l_dist}")
if not after_matches:
return (False, f"'after' search text '{after[:40]}...' not found with max_l_dist {max_l_dist}")

for before_match in before_matches:
for after_match in after_matches:
if before_match.start < after_match.start:
return (True, "")

return (False, f"Could not find a location where '{before[:40]}...' appears before '{after[:40]}...'.")


def load_tests(jsonl_file: str) -> list[BasePDFTest]:
@@ -125,10 +171,11 @@ def load_tests(jsonl_file: str) -> list[BasePDFTest]:

return tests


def save_tests(tests: list[BasePDFTest], jsonl_file: str) -> None:
"""Save tests to a JSONL file"""
with open(jsonl_file, 'w') as file:
for test in rules:
for test in tests:
# Convert dataclass to dict
if isinstance(test, TextPresenceTest):
data = {
@@ -147,6 +194,4 @@ def save_tests(tests: list[BasePDFTest], jsonl_file: str) -> None:
"before": test.before,
"after": test.after
}
file.write(json.dumps(data) + '\n')


file.write(json.dumps(data) + '\n')
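
For orientation, here is a hypothetical pair of JSONL entries in the shape that load_tests/save_tests appear to handle; the field names mirror the dataclass attributes visible in the diff (pdf, id, type, threshold, and text or before/after), but the serialization code is partly collapsed above, so treat this strictly as an assumption.

import json

present_test = {"pdf": "doc1.pdf", "id": "doc1_present_01", "type": "present",
                "text": "An expected passage of text", "threshold": 0.95}
order_test = {"pdf": "doc1.pdf", "id": "doc1_order_01", "type": "order",
              "before": "Introduction", "after": "Conclusion", "threshold": 0.9}

with open("tests.jsonl", "w") as f:          # illustrative file name
    for data in (present_test, order_test):
        f.write(json.dumps(data) + "\n")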
