add citations and correct max_len with int
siyuyuan committed Jan 25, 2025
1 parent 2da6bf8 commit 3998653
Showing 3 changed files with 10 additions and 3 deletions.
9 changes: 8 additions & 1 deletion README.md
@@ -128,6 +128,7 @@ export MODEL_NAME=YOUR_MODEL_NAME
export MODEL_DIR=YOUR_MODEL_DIR
export TASK=YOUR_TASK
export TEMP=YOUR_TEMP
export MAX_TOKEN_LENGTH=YOUR_MAX_TOKEN_LENGTH # Replace with the maximum token length for the model
export ALPHA=YOUR_ALPHA # the lower bound for high-quality trajectories
export BETA=YOUR_BETA # The distinguishable gap

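For reference, the export block above might be filled in as follows (all concrete values here are hypothetical placeholders, not values from the repository):

```shell
# Hypothetical example configuration; substitute your own model and task.
export MODEL_NAME=llama-3-8b-instruct
export MODEL_DIR=/models/llama-3-8b-instruct
export TASK=webshop
export TEMP=1.0
export MAX_TOKEN_LENGTH=7600   # maximum token length for the model
export ALPHA=0.5               # lower bound for high-quality trajectories
export BETA=0.1                # distinguishable gap
```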
@@ -203,4 +204,10 @@ Their contributions to the open-source community have been invaluable and greatl
## Citation
If you use this code in your research, please cite:
```bibtex
@article{yuan2025agent,
title={Agent-R: Training Language Model Agents to Reflect via Iterative Self-Training},
author={Yuan, Siyu and Chen, Zehui and Xi, Zhiheng and Ye, Junjie and Du, Zhengyin and Chen, Jiecao},
journal={arXiv preprint arXiv:2501.11425},
year={2025}
}
```
2 changes: 1 addition & 1 deletion mcts_collection.py
@@ -61,7 +61,7 @@ def setup_conversation(env):
def perform_mcts_search(Task, calling, env, conv, model_name, idx):

recent_actions = []
mcts_search = ExtendedMCTS(calling=calling, max_len=os.environ["MAX_TOKEN_LENGTH"], model_name=model_name, env=env, idx=idx)
mcts_search = ExtendedMCTS(calling=calling, max_len=int(os.environ["MAX_TOKEN_LENGTH"]), model_name=model_name, env=env, idx=idx)

mcts_search.search(env, conv, recent_actions)
dir_path = f"mcts_result/{Task}/{model_name}"
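The `int(...)` cast in the change above is necessary because `os.environ` always returns strings, so an uncast `MAX_TOKEN_LENGTH` would be passed to `ExtendedMCTS` as `"7600"` rather than `7600`. A minimal sketch of the difference (the value `7600` is hypothetical):

```python
import os

# Environment variable values are always strings, even when they look numeric.
os.environ["MAX_TOKEN_LENGTH"] = "7600"  # as set by the export in the README

raw = os.environ["MAX_TOKEN_LENGTH"]
assert isinstance(raw, str)  # arithmetic like raw - 60 would raise TypeError

max_len = int(raw)  # the cast applied in this commit
assert max_len - 60 == 7540  # now usable in numeric comparisons
```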
2 changes: 1 addition & 1 deletion path_collection.py
@@ -66,7 +66,7 @@ def revise_worst_path(calling, worst_path, best_path, task_description):
for node in worst_path[1:]:
# Generate the prompt for the verifier
action_obs_prompt = '\n'.join(action_obs)
max_len = 7600
max_len = int(os.environ["MAX_TOKEN_LENGTH"])
while len(action_obs_prompt.split()) > max_len - 60:
action_obs_prompt = action_obs_prompt[6:] # Truncate prompt if too long

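The loop above compares a whitespace word count against `max_len - 60` and trims 6 characters off the front of the prompt per iteration until it fits. A standalone sketch of that truncation behavior (the sample prompt and the small `MAX_TOKEN_LENGTH` value are hypothetical, chosen so the loop actually runs):

```python
import os

os.environ["MAX_TOKEN_LENGTH"] = "70"  # hypothetical; threshold becomes 70 - 60 = 10 words

max_len = int(os.environ["MAX_TOKEN_LENGTH"])
action_obs_prompt = " ".join(f"tok{i}" for i in range(20))  # 20 whitespace-separated words

# Drop 6 leading characters per pass until the word count fits the budget.
while len(action_obs_prompt.split()) > max_len - 60:
    action_obs_prompt = action_obs_prompt[6:]

assert 0 < len(action_obs_prompt.split()) <= max_len - 60
```

Note this trims the oldest (front) part of the trajectory first, keeping the most recent actions and observations.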
