diff --git a/foarcle/actions/o2actions.pyx b/foarcle/actions/o2actions.pyx index 5b4838c..da44e29 100644 --- a/foarcle/actions/o2actions.pyx +++ b/foarcle/actions/o2actions.pyx @@ -11,6 +11,7 @@ cpdef act( np.ndarray[np.uint8_t, ndim=2] inp, tuple[int, int] inp_dim, np.ndarray[np.uint8_t, ndim=2] answer, + tuple[int, int] answer_dim, np.ndarray[np.uint8_t, ndim=2] grid, tuple[int, int] grid_dim, np.ndarray[np.npy_bool, ndim=2, cast=True] selected, @@ -86,8 +87,8 @@ cpdef act( elif operation == 34: if trials_remain > 0: trials_remain -= 1 - if grid_dim[0] == answer.shape[0] and grid_dim[1] == answer.shape[1] and np.all( - grid[:grid_dim[0], :grid_dim[1]] == answer): + if grid_dim[0] == answer_dim[0] and grid_dim[1] == answer_dim[1] and np.all( + grid[:grid_dim[0], :grid_dim[1]] == answer[:grid_dim[0], :grid_dim[1]]): if trials_remain > 0: terminated = 1 reward = 1 @@ -100,7 +101,7 @@ cpdef act( active, object_, object_sel, object_dim, object_pos, background, rotation_parity, reward) cpdef batch_act( - b_inp, b_inp_dim, b_answer, + b_inp, b_inp_dim, b_answer, b_answer_dim, b_grid, b_grid_dim, b_selected, b_clip, b_clip_dim, b_terminated, b_trials_remain, b_active, b_object_, b_object_sel, b_object_dim, b_object_pos, b_background, b_rotation_parity, b_selection, b_operation): @@ -119,25 +120,26 @@ cpdef batch_act( nb_object_pos = b_object_pos.copy() nb_background = b_background.copy() nb_rotation_parity = b_rotation_parity.copy() + reward = b_active.copy() for i, ( - inp, inp_dim, answer, grid, grid_dim, selected, clip, clip_dim, terminated, trials_remain, + inp, inp_dim, answer, answer_dim, grid, grid_dim, selected, clip, clip_dim, terminated, trials_remain, active, object_, object_sel, object_dim, object_pos, background, rotation_parity, selection, operation ) in enumerate(zip( - b_inp, b_inp_dim, b_answer, b_grid, b_grid_dim, b_selected, b_clip, b_clip_dim, b_terminated, b_trials_remain, + b_inp, b_inp_dim, b_answer, b_answer_dim, b_grid, b_grid_dim, b_selected, b_clip, b_clip_dim, b_terminated, b_trials_remain, b_active, b_object_, b_object_sel, b_object_dim, b_object_pos, b_background, b_rotation_parity, b_selection, b_operation)): (nb_grid[i], nb_grid_dim[i], nb_selected[i], nb_clip[i], nb_clip_dim[i], nb_terminated[i], nb_trials_remain[i], nb_active[i], nb_object_[i], nb_object_sel[i], nb_object_dim[i], nb_object_pos[i], - nb_background[i], nb_rotation_parity[i]) = act( - inp, inp_dim, answer, grid, grid_dim, selected, clip, clip_dim, terminated, trials_remain, + nb_background[i], nb_rotation_parity[i], reward[i]) = act( + inp, inp_dim, answer, answer_dim, grid, grid_dim, selected, clip, clip_dim, terminated, trials_remain, active, object_, object_sel, object_dim, object_pos, background, rotation_parity, selection, operation) return ( nb_grid, nb_grid_dim, nb_selected, nb_clip, nb_clip_dim, nb_terminated, nb_trials_remain, nb_active, nb_object_, nb_object_sel, nb_object_dim, nb_object_pos, - nb_background, nb_rotation_parity + nb_background, nb_rotation_parity, reward ) diff --git a/foarcle/foo2arcenv.py b/foarcle/foo2arcenv.py index 9986f25..c177640 100644 --- a/foarcle/foo2arcenv.py +++ b/foarcle/foo2arcenv.py @@ -13,7 +13,7 @@ def __init__(self, data_loader, max_grid_size: Tuple[SupportsInt, SupportsInt], colors: SupportsInt, - max_trial: SupportsInt=-1): + max_trial: SupportsInt=3): self.loader = data_loader self.H, self.W = max_grid_size self.colors = colors @@ -64,8 +64,12 @@ def get_problem(self, options={}): input_dim = inp[subprob_index].shape inp = np.pad( inp[subprob_index], [(0, self.H - input_dim[0]), (0, self.W - input_dim[1])], constant_values=0).astype(np.uint8) + + answer_dim = ans[subprob_index].shape + answer = np.pad( + ans[subprob_index], [(0, self.H - answer_dim[0]), (0, self.W - answer_dim[1])], constant_values=0).astype(np.uint8) - return {'input': inp, 'input_dim': input_dim, 'answer': ans[subprob_index]} + return {'input': inp, 'input_dim': input_dim, 'answer': answer, 'answer_dim': answer_dim} def reset(self, seed=None, options={}): super().reset(seed=seed, options=options) @@ -118,6 +122,7 @@ def _step(self, state: ObsType, action: ActType, info: Dict): info['input'], info['input_dim'], info['answer'], + info['answer_dim'], state['grid'], state['grid_dim'], state['selected'],