Why not make ReplayBuffers serializable? #17
Replies: 5 comments 8 replies
-
@jamartinh I am not sure about the usage of Ray; however, if you give us sample code, your OS, and the versions of the libraries, we can investigate it. FYI, here are the internal implementations for pickling: `cpprb/cpprb/PyReplayBuffer.pyx` lines 463 to 464 and lines 2138 to 2142 (at commit d513587).
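For readers unfamiliar with the hooks referenced above: pickling support of this kind is typically built on Python's `__reduce__` protocol, which returns a constructor, its arguments, and extra state. A toy sketch of the idea (`MiniBuffer` is illustrative, not cpprb's actual implementation):

```python
import pickle
import numpy as np

class MiniBuffer:
    """Toy ring buffer illustrating __reduce__-based pickling."""
    def __init__(self, size):
        self.size = size
        self.data = np.zeros(size, dtype=np.float64)
        self.index = 0

    def add(self, v):
        self.data[self.index % self.size] = v
        self.index += 1

    def __reduce__(self):
        # Recreate via the constructor, then restore state via __setstate__.
        return (self.__class__, (self.size,), (self.data, self.index))

    def __setstate__(self, state):
        self.data, self.index = state

b = MiniBuffer(4)
b.add(1.0)
b.add(2.0)
b2 = pickle.loads(pickle.dumps(b))
print(b2.index)  # 2
```

Because `__reduce__` names the constructor arguments explicitly, the unpickled object is rebuilt through `__init__` first, and only then has its stored data restored.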
-
@jamartinh

```python
from multiprocessing.shared_memory import SharedMemory
import time

from cpprb import PrioritizedReplayBuffer
import numpy as np
import ray


@ray.remote
def worker(rb, v, shm):
    """
    Worker

    Notes
    -----
    We must pass SharedMemory directly. When we pass np.ndarray,
    the array is copied and doesn't point to shared memory any more.
    """
    done = np.ndarray(shape=tuple(), dtype=np.int32, buffer=shm.buf)
    while not done:
        print(done)
        rb.add.remote(a=v)
        time.sleep(1)
    return None


@ray.remote
class RemoteReplayBuffer:
    """
    Wrapper for Replay Buffer

    Notes
    -----
    1. All method calls are executed serially.
    2. We cannot pass the buffer class directly.
       >>> rb = ray.remote(PrioritizedReplayBuffer).remote(buffer_size, env_dict)
       TypeError: __cinit__() takes at least 1 positional argument (0 given)
    """
    def __init__(self, *args, **kwargs):
        self.rb = PrioritizedReplayBuffer(*args, **kwargs)

    def add(self, **kwargs):
        return self.rb.add(**kwargs)

    def sample(self, *args, **kwargs):
        return self.rb.sample(*args, **kwargs)

    def update_priorities(self, *args, **kwargs):
        return self.rb.update_priorities(*args, **kwargs)

    def get_stored_size(self):
        return self.rb.get_stored_size()

    def get_all_transitions(self):
        return self.rb.get_all_transitions()


def run():
    buffer_size = 32
    env_dict = {"a": {}}
    alpha = 0.5

    shm = SharedMemory(create=True, size=32)
    try:
        done = np.ndarray(shape=tuple(), dtype=np.int32, buffer=shm.buf)
        done[...] = 0

        ray.init()
        rb = RemoteReplayBuffer.remote(buffer_size, env_dict, alpha=alpha)

        w1 = worker.remote(rb, 1, shm)
        w2 = worker.remote(rb, np.asarray([2, 3]), shm)

        while True:
            stored_size = ray.get(rb.get_stored_size.remote())
            print(stored_size)
            if stored_size < 20:
                time.sleep(1)
            else:
                break

        done[...] = 1
        ray.get([w1, w2])

        print(ray.get(rb.get_stored_size.remote()))
        print(ray.get(rb.get_all_transitions.remote()))
    finally:
        # To avoid a (shared) memory leak, we must close() and unlink().
        # On Linux, you might find the shared memory file at /dev/shm
        shm.close()
        shm.unlink()


if __name__ == "__main__":
    run()
```
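The docstring's warning about passing `SharedMemory` rather than `np.ndarray` can be demonstrated in isolation. This sketch (independent of Ray and cpprb) shows that a second view over the same shared buffer observes writes, while a plain array copy does not:

```python
from multiprocessing.shared_memory import SharedMemory
import numpy as np

shm = SharedMemory(create=True, size=4)
try:
    flag = np.ndarray(shape=(), dtype=np.int32, buffer=shm.buf)
    flag[...] = 0

    # A second ndarray view over the same SharedMemory sees every write...
    view = np.ndarray(shape=(), dtype=np.int32, buffer=shm.buf)
    flag[...] = 1
    seen_by_view = int(view)      # 1

    # ...while a plain copy detaches from the shared buffer.
    detached = np.array(flag)
    flag[...] = 2
    seen_by_copy = int(detached)  # still 1

    # Release the ndarray views before closing, or close() raises BufferError.
    del flag, view
finally:
    # Without close() and unlink() the segment leaks
    # (visible under /dev/shm on Linux).
    shm.close()
    shm.unlink()

print(seen_by_view, seen_by_copy)  # 1 1
```

This is exactly why the worker above receives `shm` itself and reconstructs the `done` array inside the remote function.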
-
It is working for me now, for a simple MPReplayBuffer with Ray. I have pushed the code to a draft pull request: #19. For this to work, you either have to put … in every actor and pass the global_replay_buffer during a …, or do … and then, after that, … before calling ray.init. The whole story boils down to making the …
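The setup described here (later shown in full in the v10.6.0 announcement below) revolves around restoring the process `authkey` so workers can talk to the manager-backed buffer. A minimal sketch of just that piece, with illustrative names (`restore_authkey` is not the actual code from PR #19):

```python
import base64
import multiprocessing as mp

# Pickling mp.current_process().authkey directly raises:
#   TypeError: Pickling an AuthenticationString object is disallowed
#   for security reasons
# so we encode it to plain base64 bytes first.
encoded = base64.b64encode(mp.current_process().authkey)

def restore_authkey():
    # Run this in each worker process, before it touches the shared buffer.
    mp.current_process().authkey = base64.b64decode(encoded)

restore_authkey()
print(mp.current_process().authkey == base64.b64decode(encoded))  # True
```

The base64 round-trip exists only to smuggle the `AuthenticationString` past pickle's security check; the restored key is byte-for-byte identical.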
-
@jamartinh Based on your work, we finally released cpprb v10.6.0. After some design modification, these buffers take new construction parameters for the context (`ctx`) and the backend (`backend`):

```python
import base64
import multiprocessing as mp

from cpprb import MPReplayBuffer
import ray

ray.init()

# Encode base64 to avoid:
# TypeError: Pickling an AuthenticationString object is disallowed
# for security reasons
encoded = base64.b64encode(mp.current_process().authkey)

def auth_fn(*args):
    mp.current_process().authkey = base64.b64decode(encoded)

ray.worker.global_worker.run_function_on_all_workers(auth_fn)

buffer_size = 1e+6

m = mp.get_context().Manager()

# Use `SyncManager` as context, "SharedMemory" as backend
rb = MPReplayBuffer(buffer_size, {"done": {}}, ctx=m, backend="SharedMemory")
```

Thank you for your great contribution.
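With `ctx=m`, the buffer's coordination state lives behind a `SyncManager`, and what gets passed between processes is a proxy, not the data. Manager proxies are picklable by design, which is the property that makes this work with Ray. A minimal demonstration with a plain manager dict (no cpprb involved; on spawn-based platforms such as Windows and macOS, wrap this in an `if __name__ == "__main__":` guard):

```python
import multiprocessing as mp
import pickle

m = mp.get_context().Manager()
d = m.dict()
d["stored_size"] = 0

# SyncManager proxies pickle fine: they serialize the server address
# (plus authkey-based credentials), not the underlying data. The
# unpickled proxy reconnects to the same manager server.
d2 = pickle.loads(pickle.dumps(d))
d2["stored_size"] = 5

value = d["stored_size"]  # both proxies talk to the same server
print(value)              # 5

m.shutdown()
```

The corollary is the `authkey` dance above: the unpickled proxy can only reconnect if the receiving process holds the same `authkey` as the process that created the manager.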
-
I found a problem with Ray's (private) method `run_function_on_all_workers`, so we re-considered the usage:

```python
@ray.remote
class RemoteWorker:
    # Encode base64 to avoid the following error:
    # TypeError: Pickling an AuthenticationString object is disallowed
    # for security reasons
    encoded = base64.b64encode(mp.current_process().authkey)

    def __init__(self):
        # Set up 'authkey' to communicate with `SyncManager`.
        # Important: Do not pass `MPReplayBuffer` here, because it is not ready.
        mp.current_process().authkey = base64.b64decode(self.encoded)

    def run(self, rb):
        pass


w = RemoteWorker.remote()
w.run.remote(rb)
```

We updated the example, too.
-
It would be very good to make replay buffers serializable.
For instance, how can I pass and share a buffer using the ray library?
I tried to use the MPReplayBuffer, but ray says it is not serializable.
Perhaps the serialization/deserialization of the MPReplayBuffer could be as easy as passing the necessary info to recreate a buffer pointing to the shared memory.