diff --git a/agent_runner.py b/agent_runner.py index da5f2f7..791dcc1 100644 --- a/agent_runner.py +++ b/agent_runner.py @@ -114,6 +114,33 @@ def run_agent_eval( results: List[Tuple[int, Dict[str, Any], str]] = [] tasks = list(range(len(dataset))) + tasks_to_run = tasks + if reuse: + tasks_to_run = [] + for idx in tasks: + if do_eval: + eval_cached = store.load_eval(idx) + if eval_cached is not None: + cached_score = eval_cached.get("score", eval_cached) + cached_final = eval_cached.get("final_answer", "") + if not cached_final: + traj = store.load_traj(idx) + if traj is not None: + cached_final = traj.get("final_answer", "") + results.append((idx, cached_score, cached_final)) + continue + tasks_to_run.append(idx) + continue + + if do_infer: + traj = store.load_traj(idx) + if traj and traj.get("success"): + results.append((idx, {}, traj.get("final_answer", ""))) + else: + tasks_to_run.append(idx) + else: + tasks_to_run.append(idx) + if nproc > 1: with ThreadPoolExecutor(max_workers=nproc) as executor: futures = [ @@ -128,15 +155,15 @@ def run_agent_eval( do_infer, do_eval, ) - for idx in tasks + for idx in tasks_to_run ] - with tqdm(total=len(tasks), desc="Agent Eval", unit="sample") as pbar: + with tqdm(total=len(tasks_to_run), desc="Agent Eval", unit="sample") as pbar: for fut in as_completed(futures): results.append(fut.result()) pbar.update(1) else: - with tqdm(total=len(tasks), desc="Agent Eval", unit="sample") as pbar: - for idx in tasks: + with tqdm(total=len(tasks_to_run), desc="Agent Eval", unit="sample") as pbar: + for idx in tasks_to_run: results.append( _run_one_sample( idx, agent, dataset, store, judge_kwargs, reuse, do_infer, do_eval diff --git a/scieval/agents/smolagents.py b/scieval/agents/smolagents.py index 32ce090..707c104 100644 --- a/scieval/agents/smolagents.py +++ b/scieval/agents/smolagents.py @@ -104,7 +104,7 @@ def __init__( ): super().__init__(name=self.name, model_version=model_version) self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "") - self.api_base = api_base or os.environ.get("OPENAI_BASE_URL", "") + self.api_base = api_base or os.environ.get("OPENAI_API_BASE", "") self.model_version = model_version or os.environ.get("MODEL_ID", "o3") def run(self, sample: EvalSample) -> EvalResult: