diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index 23bc1065..172d0ca4 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -206,6 +206,10 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): # dim=-1, # ) + # If all trajectories were filtered, stop there. + if not len(training_objects): + return + if self.cutoff_distance >= 0: # Filter the batch for diverse final_states with high reward. batch = training_objects.last_states.tensor.float()