Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aorwall committed Aug 14, 2024
1 parent e7706ab commit 37a8dbe
Show file tree
Hide file tree
Showing 5 changed files with 500 additions and 466 deletions.
31 changes: 28 additions & 3 deletions moatless/benchmark/swebench/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json
import logging
import os
from datetime import datetime, timezone
from typing import Optional

from datasets import load_dataset

from moatless.benchmark.utils import (
file_spans_to_dict,
get_missing_files,
get_missing_spans,
get_missing_files, get_missing_spans,
)
from moatless.file_context import FileContext
from moatless.index import CodeIndex
Expand All @@ -19,6 +20,8 @@
logger = logging.getLogger(__name__)


_moatless_instances = {}

def load_instances(
dataset_name: str = "princeton-nlp/SWE-bench_Lite", split: str = "test"
):
Expand All @@ -35,6 +38,21 @@ def load_instance(
return data[instance_id]


def load_moatless_dataset():
global _moatless_instances
with open("moatless/benchmark/swebench_lite_all_evaluations.json") as f:
dataset = json.load(f)
_moatless_instances = {d["instance_id"]: d for d in dataset}

def get_moatless_instance(
instance_id: str
):
global _moatless_instances
if not _moatless_instances:
load_moatless_dataset()
return _moatless_instances.get(instance_id)


def sorted_instances(
dataset_name: str = "princeton-nlp/SWE-bench_Lite",
split: str = "test",
Expand All @@ -56,6 +74,7 @@ def found_in_expected_spans(instance: dict, spans: dict):
logging.warning(
f"{instance['instance_id']} Expected spans for {file_path} is empty"
)

missing_spans = get_missing_spans(instance["expected_spans"], spans)
return not missing_spans

Expand Down Expand Up @@ -306,6 +325,7 @@ def create_workspace(
instance_id: Optional[str] = None,
repo_base_dir: Optional[str] = None,
index_store_dir: Optional[str] = None,
create_instance_dir: bool = False,
):
"""
Create a workspace for the given SWE-bench instance.
Expand All @@ -322,7 +342,12 @@ def create_workspace(

repo_dir_name = instance["repo"].replace("/", "__")
repo_url = f"https://github.com/swe-bench/{repo_dir_name}.git"
repo_dir = f"{repo_base_dir}/swe-bench_{repo_dir_name}"

if create_instance_dir:
date_str = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
repo_dir = f"{repo_base_dir}/swe-bench_{instance['instance_id']}_{date_str}"
else:
repo_dir = f"{repo_base_dir}/{repo_dir_name}"
repo = GitRepository.from_repo(
git_repo_url=repo_url, repo_path=repo_dir, commit=instance["base_commit"]
)
Expand Down
5 changes: 4 additions & 1 deletion moatless/benchmark/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import re
import time
from typing import List

from moatless.codeblocks.module import Module
from moatless.index.types import SearchCodeHit, CodeSnippet
from moatless.repository import FileRepository
from moatless.types import FileWithSpans

Expand Down Expand Up @@ -199,7 +201,8 @@ def get_missing_spans(


def calculate_estimated_context_window(instance, results):
patch_diffs = get_diff_lines(instance["patch"])
patch = instance.get("patch") or instance.get("golden_patch")
patch_diffs = get_diff_lines(patch)
expected_changes = []

for patch_diff in patch_diffs:
Expand Down
32 changes: 12 additions & 20 deletions moatless/repository/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,34 +27,32 @@ class UpdateResult:
class CodeFile(BaseModel):
file_path: str = Field(..., description="The path to the file")
content: str = Field(..., description="The content of the file")

_module: Module | None = PrivateAttr(None)
_dirty: bool = PrivateAttr(False)

@classmethod
def from_file(cls, repo_path: str, file_path: str):
with open(os.path.join(repo_path, file_path)) as f:
parser = get_parser_by_path(file_path)
if parser:
content = f.read()
module = parser.parse(content)
else:
module = None
return cls(file_path=file_path, content=content, module=module)
content = f.read()
return cls(file_path=file_path, content=content)

@classmethod
def from_content(cls, file_path: str, content: str):
parser = PythonParser()
module = parser.parse(content)
return cls(file_path=file_path, content=content, module=module)
return cls(file_path=file_path, content=content)

@property
def supports_codeblocks(self):
return self.module is not None

@property
def module(self) -> Module:
if not self._module:
return None
def module(self) -> Module | None:
if self._module is None:
parser = get_parser_by_path(self.file_path)
if parser:
self._module = parser.parse(self.content)
else:
return None
return self._module

@property
Expand Down Expand Up @@ -225,13 +223,7 @@ def get_file(
return None

with open(full_file_path) as f:
parser = get_parser_by_path(file_path)
if parser:
content = f.read()
module = parser.parse(content)
found_file = CodeFile(file_path=file_path, content=content, module=module)
else:
found_file = CodeFile(file_path=file_path, content=f.read())
found_file = CodeFile(file_path=file_path, content=f.read())

if not existing_file:
existing_file = found_file
Expand Down
Loading

0 comments on commit 37a8dbe

Please sign in to comment.