Async mode in PPO/GRPO #580
Comments
Also, the else condition should have:
Thanks for looking into it closely! You are right. The correct code should be:

import queue
import threading
import time

class Agent():
    def __init__(self):
        self.param = 1

    def learn(self, data):
        self.param += 1

def query_generator_fn():
    for i in range(1, 100):
        yield i

ITER = 7
batch_size = 32
agent = Agent()
data_Q = queue.Queue(maxsize=1)
param_and_query_Q = queue.Queue(maxsize=1)

def actor():
    for i in range(1, ITER + 1):
        params, query = param_and_query_Q.get()
        data = params
        print(f"[actor] generating data π_{params} -> p_{query} D_π_{data}")
        time.sleep(1) # simulate data generation
        data_Q.put((query, data))

actor_thread = threading.Thread(target=actor)
actor_thread.start()

# initial param put
generator = query_generator_fn()
next_queries = next(generator)
param_and_query_Q.put((agent.param, next_queries))

# cleanba style stuff
async_mode = True
start_time = time.time()
for g in range(1, ITER + 1):
    queries = next_queries
    if async_mode:
        if g != 1:
            next_queries = next(generator)
        param_and_query_Q.put((agent.param, next_queries))
    else:
        if g != 1:
            next_queries = next(generator)
            param_and_query_Q.put((agent.param, next_queries)) # note the indent here is different
            queries = next_queries
    _, data = data_Q.get()
    old_param = agent.param
    agent.learn(data)
    time.sleep(1) # simulate training
    print(f"--[leaner] get π_{old_param} -> p_{queries} D_π_{data} -> π_{agent.param}, time: {time.time() - start_time}")
actor_thread.join()

Async mode: Existing incorrect code
correct (because
Sync mode:Existing incorrect code (
correct (because
Hello,
This is not about an issue in the code itself, but about this readme: https://github.com/allenai/open-instruct/blob/main/docs/algorithms/ppo.md
In the last code block there, inside the if async_mode condition,
I think
param_and_query_Q.put((agent.param, queries))
should be replaced by
param_and_query_Q.put((agent.param, next_queries))
Please let me know if that is not the case. Thanks!
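To make the difference concrete, here is a minimal single-threaded sketch of the same learner loop. It is not the code from the readme: the function name trace, the deque named pending, and the put_next flag are hypothetical and exist only for this illustration. It assumes the actor simply passes each (param, query) pair through in FIFO order, so one deque can stand in for both queues and the actor thread. It models only the order of items, not timing or blocking.

```python
from collections import deque


def trace(async_mode, iters=4, put_next=True):
    """Return one (learner_policy, data_query, data_policy) tuple per step.

    put_next=True  puts (agent.param, next_queries), the proposed fix.
    put_next=False puts (agent.param, queries), the line questioned above.
    The flag only affects the async branch.
    """
    pending = deque()  # stands in for param_and_query_Q -> actor -> data_Q
    param = 1
    generator = iter(range(1, 100))
    next_queries = next(generator)
    pending.append((param, next_queries))  # initial param put

    log = []
    for g in range(1, iters + 1):
        queries = next_queries
        if async_mode:
            if g != 1:
                next_queries = next(generator)
            pending.append((param, next_queries if put_next else queries))
        else:
            if g != 1:
                next_queries = next(generator)
                pending.append((param, next_queries))
        data_policy, data_query = pending.popleft()  # what data_Q.get() would yield
        log.append((param, data_query, data_policy))
        param += 1  # agent.learn(data)
    return log


if __name__ == "__main__":
    print("sync           :", trace(async_mode=False))
    print("async, fixed   :", trace(async_mode=True, put_next=True))
    print("async, original:", trace(async_mode=True, put_next=False))
```

Under these assumptions, sync mode trains policy g on query batch g generated by policy g. Async mode with next_queries trains on data one policy version behind, and from the third step on each query batch is paired with the policy that was current when it was submitted. With queries instead, query batch 1 is submitted three times and the submitted queries lag one batch behind the policy that generates them.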