
feat: implementing STaR: Self Taught Reasoner #1478

Merged Feb 7, 2025 (38 commits from feat/star-datagen into master).

Commits:
8bf5a02 - implementing STaR: Self Taught Reasoner (GitHoobar, Jan 21, 2025)
245d023 - minor fixes (GitHoobar, Jan 23, 2025)
fa8406f - minor fixes (GitHoobar, Jan 24, 2025)
a1875e1 - Merge branch 'master' into feat/star-datagen (Wendong-Fan, Jan 26, 2025)
4985ec4 - Merge branch 'master' into feat/star-datagen (Wendong-Fan, Jan 30, 2025)
44bea0d - enhance: STaR Integration (#1514) (Wendong-Fan, Jan 30, 2025)
1a5bd8e - minor fixes (GitHoobar, Jan 30, 2025)
a78d517 - bug fix (GitHoobar, Jan 30, 2025)
f46e3c6 - fix (GitHoobar, Jan 30, 2025)
7169cea - update (Wendong-Fan, Jan 30, 2025)
0634891 - update with example math 500 data (Wendong-Fan, Jan 30, 2025)
047d4ab - add aime24 (Wendong-Fan, Jan 30, 2025)
88b8bfc - update (Wendong-Fan, Jan 30, 2025)
6fbc4b8 - update aime24 and amc23 (Wendong-Fan, Jan 30, 2025)
a9fffe1 - add gaokao2023 and gsm8k (Wendong-Fan, Jan 30, 2025)
f47118a - update (Wendong-Fan, Jan 30, 2025)
b24f710 - add dynamic writing to the generated data file (Wendong-Fan, Jan 30, 2025)
9dd9d23 - pre commit and test fix (Wendong-Fan, Jan 30, 2025)
e9fe5f8 - uddate data gen example (Wendong-Fan, Jan 30, 2025)
ac7d62f - update data set (Wendong-Fan, Jan 31, 2025)
b46b4b9 - add amc aime data (Wendong-Fan, Jan 31, 2025)
8d20395 - bug: fixed broken json issue by adding a lock (AveryYay, Jan 31, 2025)
220c3db - pre-commit fix (Jan 31, 2025)
c4382e5 - update pipleline and ouput (Wendong-Fan, Jan 31, 2025)
c5b11a2 - update example (Wendong-Fan, Jan 31, 2025)
c151e1c - update (Wendong-Fan, Jan 31, 2025)
17bfab0 - update pipeline (WHALEEYE, Jan 31, 2025)
4f4e473 - part3 output (AveryYay, Jan 31, 2025)
de9fd43 - Merge remote-tracking branch 'origin/feat/star-datagen' into feat/sta… (AveryYay, Jan 31, 2025)
4210cb0 - update (Wendong-Fan, Feb 2, 2025)
e358955 - update (Wendong-Fan, Feb 4, 2025)
b14ce0c - update dateset downloading (Wendong-Fan, Feb 4, 2025)
21067bf - update pre commit and pytest (Wendong-Fan, Feb 4, 2025)
5ae99fb - PR clean (Wendong-Fan, Feb 4, 2025)
f028e39 - update (Wendong-Fan, Feb 4, 2025)
79c5504 - Merge branch 'master' into feat/star-datagen (Wendong-Fan, Feb 7, 2025)
79c08f9 - update naming (Wendong-Fan, Feb 7, 2025)
3232a77 - add deep dive blog (Wendong-Fan, Feb 7, 2025)
17 changes: 17 additions & 0 deletions camel/datagen/star/__init__.py
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from .star_pipeline import STaRPipeline

__all__ = ['STaRPipeline']
305 changes: 305 additions & 0 deletions camel/datagen/star/star_pipeline.py
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import json
from typing import Any, Dict, List, Optional

from pydantic import BaseModel

from camel.agents import ChatAgent
from camel.models.reward import BaseRewardModel, Evaluator


class TraceEvaluation(BaseModel):
    correctness: float
    clarity: float
    completeness: float
    feedback: str


class TraceIteration(BaseModel):
    iteration: int
    trace: str
    evaluation: TraceEvaluation


class ProblemResult(BaseModel):
    problem: str
    final_trace: str
    improvement_history: List[TraceIteration]


class STaRPipeline:
    r"""Pipeline for generating self-taught reasoning traces
    using the STaR methodology.

    This implements the STaR paper's approach of:
    1. Initial reasoning trace generation
    2. Self-evaluation
    3. Feedback-based improvement
    4. Iterative refinement

    Args:
        agent (ChatAgent): The chat agent used for generating and improving
            reasoning traces.
        problems_path (str): Path to the JSON file containing reasoning
            problems.
        output_path (str, optional): Output path for saving traces.
            (default: :obj:`'./star_output.json'`)
        max_iterations (int, optional): Maximum number of improvement
            iterations. (default: :obj:`3`)
        score_threshold (float, optional): Average quality score above which
            iteration stops. (default: :obj:`0.7`)
        reward_model (BaseRewardModel, optional): Model used for evaluating
            reasoning traces. If None, uses LLM self-evaluation.
            (default: :obj:`None`)
    """

    def __init__(
        self,
        agent: ChatAgent,
        problems_path: str,
        output_path: Optional[str] = './star_output.json',
        max_iterations: int = 3,
        score_threshold: float = 0.7,
        reward_model: Optional[BaseRewardModel] = None,
    ):
        r"""Initialize the STaR pipeline.

        Args:
            agent (ChatAgent): The chat agent used for generating and
                improving reasoning traces.
            problems_path (str): Path to the problems JSON file.
            output_path (str, optional): Output path for saving traces.
                (default: :obj:`'./star_output.json'`)
            max_iterations (int, optional): Maximum number of improvement
                iterations. (default: :obj:`3`)
            score_threshold (float, optional): Quality threshold.
                (default: :obj:`0.7`)
            reward_model (BaseRewardModel, optional): Model used to evaluate
                reasoning traces. If None, uses LLM self-evaluation.
                (default: :obj:`None`)
        """
        self.agent = agent
        self.problems = self.load_problems(problems_path)
        self.output_path = output_path
        self.max_iterations = max_iterations
        self.score_threshold = score_threshold
        self.reward_model = reward_model
        self.evaluator = (
            Evaluator(reward_model=reward_model) if reward_model else None
        )
        self.reasoning_traces: List[Dict[str, Any]] = []

    def load_problems(self, path: str) -> List[Dict]:
        r"""Load reasoning problems from a JSON file.

        Args:
            path (str): Path to the JSON file containing the problems.

        Returns:
            List[Dict]: List of problem dictionaries loaded from the file.
        """
        with open(path, 'r') as f:
            data = json.load(f)
        return data['problems']

    def generate_reasoning_trace(self, problem: str) -> str:
        r"""Generate an initial reasoning trace for a given problem.

        Args:
            problem (str): The problem text to generate reasoning for.

        Returns:
            str: Generated reasoning trace.
        """
        self.agent.reset()
        prompt = self.REASONING_TEMPLATE.format(problem=problem)
        response = self.agent.step(prompt)
        return response.msg.content

    def evaluate_trace(self, problem: str, trace: str) -> Dict[str, Any]:
        r"""Evaluate the quality of a reasoning trace.

        Args:
            problem (str): The original problem text to evaluate against.
            trace (str): The reasoning trace to evaluate.

        Returns:
            Dict[str, Any]: Evaluation results containing:
                - correctness (float): Score for logical correctness
                - clarity (float): Score for clarity of explanation
                - completeness (float): Score for completeness of reasoning
                - feedback (str): Detailed feedback for improvement
        """
        self.agent.reset()
        if self.evaluator:
            # Use reward model evaluation; raw scores on a 0-5 scale are
            # normalized to 0-1.
            messages = [
                {"role": "user", "content": problem},
                {"role": "assistant", "content": trace},
            ]
            scores = self.evaluator.evaluate(messages)
            return {
                "correctness": scores.get(
                    "correctness", scores.get("Score", 0)
                )
                / 5.0,
                "clarity": scores.get("coherence", scores.get("Score", 0))
                / 5.0,
                "completeness": scores.get(
                    "helpfulness", scores.get("Score", 0)
                )
                / 5.0,
                "feedback": "Evaluation by reward model",
            }
        else:
            # Fall back to LLM self-evaluation
            prompt = self.EVALUATION_TEMPLATE.format(
                problem=problem, trace=trace
            )
            response = self.agent.step(prompt, response_format=TraceEvaluation)
            if response.msg.parsed is None:
                raise AttributeError("Failed to parse evaluation response")
            # Convert dict to TraceEvaluation if needed
            if isinstance(response.msg.parsed, dict):
                evaluation = TraceEvaluation(**response.msg.parsed)
            else:
                evaluation = response.msg.parsed

            return evaluation.model_dump()

Review comment (Collaborator): The original paper mentioned the term "rationalization," but I don't seem to see a similar implementation in this improvement method.

Reply (Author): You're correct. The original STaR (Self-Taught Reasoner) paper uses rationalization, but our current implementation is different: it is a test-time method that directly generates reasoning without access to ground-truth solutions. We're focusing on the data generation part at the moment. cc: @Wendong-Fan

    def improve_trace(self, problem: str, trace: str, feedback: str) -> str:
        r"""Generate an improved reasoning trace based on feedback.

        Args:
            problem (str): The original problem text.
            trace (str): The current reasoning trace.
            feedback (str): Feedback for improving the trace.

        Returns:
            str: Improved reasoning trace.
        """
        self.agent.reset()
        prompt = self.IMPROVEMENT_TEMPLATE.format(
            problem=problem, trace=trace, feedback=feedback
        )
        response = self.agent.step(prompt)
        return response.msg.content

Review comment (Member): ChatAgent in camel is stateful; we need to use self.agent.reset() to clear the memory for the next step.

Review comment (Member): Resolve the conversation if you have updated.

    def process_problem(self, problem: Dict) -> Dict[str, Any]:
        r"""Process a single problem through the STaR pipeline.

        Args:
            problem (Dict): Problem dictionary containing the problem text.

        Returns:
            Dict[str, Any]: Results with the final trace and improvement
                history.
        """
        problem_text = problem['problem']
        current_trace = self.generate_reasoning_trace(problem_text)
        traces = []

        for iteration in range(self.max_iterations):
            # Evaluate the current trace
            eval_dict = self.evaluate_trace(problem_text, current_trace)
            evaluation = TraceEvaluation(**eval_dict)

            # Check whether the quality threshold is met
            avg_score = (
                evaluation.correctness
                + evaluation.clarity
                + evaluation.completeness
            ) / 3

            traces.append(
                TraceIteration(
                    iteration=iteration,
                    trace=current_trace,
                    evaluation=evaluation,
                )
            )

            if avg_score >= self.score_threshold:
                break

            # Generate an improved trace
            if iteration < self.max_iterations - 1:
                current_trace = self.improve_trace(
                    problem_text, current_trace, evaluation.feedback
                )

        result = ProblemResult(
            problem=problem_text,
            final_trace=current_trace,
            improvement_history=traces,
        )

        return result.model_dump()

    def generate(self):
        r"""Execute the STaR pipeline on all problems.

        Process each problem and save the results.
        """
        for problem in self.problems:
            result = self.process_problem(problem)
            self.reasoning_traces.append(result)

        if self.output_path:
            with open(self.output_path, 'w') as f:
                json.dump(self.reasoning_traces, f, indent=2)

Review comment (Collaborator): When constructing the reasoning prompt, consider adding few-shot examples, as this can improve performance to some extent. The original paper also adopts this approach.

Reply (Author): This has been taken care of.

    # Templates for generating reasoning traces, evaluating them, and
    # improving them.
    REASONING_TEMPLATE = """Let's solve this step by step:
Problem: {problem}
1. First, let's understand what we're asked
2. Let's break this down into parts
3. Let's solve each part systematically
4. Finally, let's verify our solution

Please show your complete reasoning process."""

    EVALUATION_TEMPLATE = """Please evaluate this reasoning trace and
provide scores and feedback in valid JSON format.

Problem: {problem}

Reasoning Trace:
{trace}

Evaluate for:
1. Correctness (Is each step logically sound?)
2. Clarity (Is the explanation clear and well-structured?)
3. Completeness (Are all necessary steps included?)

Respond ONLY with a JSON object in this exact format:
{{
    "correctness": <score between 0 and 1>,
    "clarity": <score between 0 and 1>,
    "completeness": <score between 0 and 1>,
    "feedback": "<specific feedback for improvement>"
}}"""

    IMPROVEMENT_TEMPLATE = """Based on this feedback, generate an
improved reasoning trace:
Problem: {problem}

Previous Trace:
{trace}

Feedback:
{feedback}

Generate a new, improved reasoning trace that addresses the feedback."""
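The evaluate-then-improve loop in process_problem can be sketched in isolation with the model calls stubbed out. In this stdlib-only sketch, `refine`, `fake_generate`, `fake_evaluate`, and `fake_improve` are hypothetical stand-ins, not part of the PR; only the control flow (evaluate, average the three scores, stop at the threshold, otherwise improve) mirrors the pipeline:

```python
# Stdlib-only sketch of the STaR refinement loop. The three callables
# stand in for the agent-backed generate/evaluate/improve steps.
MAX_ITERATIONS = 3
SCORE_THRESHOLD = 0.7

def refine(generate, evaluate, improve):
    trace = generate()
    history = []
    for iteration in range(MAX_ITERATIONS):
        scores = evaluate(trace)
        # Average the three scores, as in process_problem
        avg = (scores["correctness"] + scores["clarity"]
               + scores["completeness"]) / 3
        history.append({"iteration": iteration, "trace": trace,
                        "avg_score": avg})
        if avg >= SCORE_THRESHOLD:
            break  # quality threshold met, stop refining
        if iteration < MAX_ITERATIONS - 1:
            trace = improve(trace, scores["feedback"])
    return trace, history

# Simulated run: each improvement step raises quality by 0.3
quality = {"v": 0.2}

def fake_generate():
    return "draft-0"

def fake_evaluate(trace):
    v = quality["v"]
    return {"correctness": v, "clarity": v, "completeness": v,
            "feedback": "be more rigorous"}

def fake_improve(trace, feedback):
    quality["v"] += 0.3
    return trace + "+improved"

final, hist = refine(fake_generate, fake_evaluate, fake_improve)
```

With these stubs the loop improves twice (average scores 0.2 and 0.5 fall below 0.7), then stops on the third evaluation once the average reaches about 0.8.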
28 changes: 28 additions & 0 deletions examples/star_datagen/input_problems.json
{
"problems": [
{
"id": "problem_0",
"problem": "Prove that for any positive integer n, if n² is even, then n is even. Show all steps and explain why each step is valid.",
"type": "mathematical_proof",
"solution": "n is even"
},
{
"id": "problem_1",
"problem": "A cylindrical water tank has a radius of 3 meters and a height of 4 meters. If water is flowing into the tank at a rate of 2 cubic meters per minute, and simultaneously draining at a rate of 1 cubic meter per minute, how long will it take for the tank to fill up completely? Use π = 3.14 and show all calculations.",
"type": "word_problem",
"solution": "113.04 minutes"
},
{
"id": "problem_2",
"problem": "Find all values of x that satisfy the equation: |x - 2| + |x + 1| = 7. Show your reasoning and verify all solutions.",
"type": "algebra",
"solution": "x = -3 or x = 4"
},
{
"id": "problem_3",
"problem": "In how many ways can 8 different books be arranged on a shelf if 3 specific books must always be kept together (but can be arranged in any order among themselves)? Explain your approach and verify the answer.",
"type": "combinatorics",
"solution": "4320"
}
]
}
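The problems file is a JSON object with a top-level "problems" list, which matches what load_problems expects (it returns `data['problems']`). A minimal stdlib shape check, using an inline copy of the problem_0 entry (the `validate_problems` helper is illustrative only, not part of the PR):

```python
import json

# Minimal shape check for a STaR problems file: a top-level "problems"
# list whose entries carry at least "id" and "problem" keys.
sample = """
{
  "problems": [
    {
      "id": "problem_0",
      "problem": "Prove that for any positive integer n, if n^2 is even, then n is even.",
      "type": "mathematical_proof",
      "solution": "n is even"
    }
  ]
}
"""

def validate_problems(text: str) -> list:
    data = json.loads(text)
    problems = data["problems"]  # same access path as load_problems
    for entry in problems:
        assert "id" in entry and "problem" in entry, "missing required key"
    return problems

problems = validate_problems(sample)
```

Running this against a malformed file (missing "problems" or a required key) raises immediately, which is cheaper than discovering the problem mid-pipeline.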