diff --git a/src/search_file.py b/src/search_file.py index 5d5bf5840..f3efb6952 100755 --- a/src/search_file.py +++ b/src/search_file.py @@ -105,6 +105,9 @@ class ArgsMismatchException(Exception): class SourceChangedException(Exception): pass +class KilledException(Exception): + pass + class SearchResult(NamedTuple): status: SearchStatus @@ -231,6 +234,7 @@ def parse_arguments(args_list: List[str]) -> Tuple[argparse.Namespace, parser.add_argument("--max-term-length", type=int, default=256) parser.add_argument("--add-env-lemmas", type=Path2, default=None) parser.add_argument("--add-axioms", type=Path2, default=None) + parser.add_argument("--max-search-time-per-lemma", default=None, type=float) if __name__ == "__main__": known_args = parser.parse_args(args_list) else: @@ -396,6 +400,31 @@ def search_file_worker(args: argparse.Namespace, coq, worker_idx, predictor, predictor_lock) + except KilledException: + solution = [ + TacticInteraction("Proof.", initial_context), + TacticInteraction("Admitted.", initial_context) + ] + done.put(((next_file, coq.module_prefix, + next_lemma), + SearchResult(SearchStatus.INCOMPLETE, + solution))) + try: + next_job = jobs.get_nowait() + except queue.Empty: + return + new_file, next_module, next_lemma = next_job + if new_file != next_file: + next_file = new_file + with util.silent(): + all_commands = serapi_instance.\ + load_commands_preserve( + args, 0, + args.prelude / next_file) + rest_commands = all_commands + else: + rest_commands = all_commands + break except serapi_instance.CoqAnomaly: if args.hardfail: raise @@ -906,6 +935,8 @@ def get_lemma_declaration_from_name(coq: serapi_instance.SerapiInstance, # The core of the search report +import _thread +import threading # This method attempts to complete proofs using search. def attempt_search(args: argparse.Namespace, @@ -923,11 +954,18 @@ def attempt_search(args: argparse.Namespace, for lemma_name in f] else: env_lemmas = [] - result = dfs_proof_search_with_graph(lemma_statement, module_name, - env_lemmas, - coq, - args, bar_idx, predictor, - predictor_lock) + timer = threading.Timer(args.max_search_time_per_lemma, _thread.interrupt_main) + timer.start() + try: + result = dfs_proof_search_with_graph(lemma_statement, module_name, + env_lemmas, + coq, + args, bar_idx, predictor, + predictor_lock) + except: + raise KilledException("Lemma timeout") + finally: + timer.cancel() return result @@ -1149,13 +1187,12 @@ def search(pbar: tqdm, current_path: List[LabeledNode], coq.prev_tactics, coq.hypotheses, coq.goals) - with predictor_lock: - with util.silent(): - predictions = predictor.predictKTactics( - truncate_tactic_context(tactic_context_before, - args.max_term_length), - args.max_attempts) - assert len(predictions) == args.max_attempts + with util.silent(): + predictions = predictor.predictKTactics( + truncate_tactic_context(tactic_context_before, + args.max_term_length), + args.max_attempts) + assert len(predictions) == args.max_attempts proof_context_before = coq.proof_context if coq.use_hammer: predictions = [Prediction(prediction.prediction[:-1] + "; try hammer.",