Skip to content

Commit

Permalink
added asnwer_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqudwns committed Jan 4, 2024
1 parent 1bf367e commit f805623
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
18 changes: 10 additions & 8 deletions foarcle/actions/o2actions.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ cpdef act(
np.ndarray[np.uint8_t, ndim=2] inp,
tuple[int, int] inp_dim,
np.ndarray[np.uint8_t, ndim=2] answer,
tuple[int, int] answer_dim,
np.ndarray[np.uint8_t, ndim=2] grid,
tuple[int, int] grid_dim,
np.ndarray[np.npy_bool, ndim=2, cast=True] selected,
Expand Down Expand Up @@ -86,8 +87,8 @@ cpdef act(
elif operation == 34:
if trials_remain > 0:
trials_remain -= 1
if grid_dim[0] == answer.shape[0] and grid_dim[1] == answer.shape[1] and np.all(
grid[:grid_dim[0], :grid_dim[1]] == answer):
if grid_dim[0] == answer_dim[0] and grid_dim[1] == answer_dim[1] and np.all(
grid[:grid_dim[0], :grid_dim[1]] == answer[:grid_dim[0], :grid_dim[1]]):
if trials_remain > 0:
terminated = 1
reward = 1
Expand All @@ -100,7 +101,7 @@ cpdef act(
active, object_, object_sel, object_dim, object_pos, background, rotation_parity, reward)

cpdef batch_act(
b_inp, b_inp_dim, b_answer,
b_inp, b_inp_dim, b_answer, b_answer_dim,
b_grid, b_grid_dim, b_selected, b_clip, b_clip_dim, b_terminated, b_trials_remain,
b_active, b_object_, b_object_sel, b_object_dim, b_object_pos, b_background, b_rotation_parity,
b_selection, b_operation):
Expand All @@ -119,25 +120,26 @@ cpdef batch_act(
nb_object_pos = b_object_pos.copy()
nb_background = b_background.copy()
nb_rotation_parity = b_rotation_parity.copy()
reward = b_active.copy()

for i, (
inp, inp_dim, answer, grid, grid_dim, selected, clip, clip_dim, terminated, trials_remain,
inp, inp_dim, answer, answer_dim, grid, grid_dim, selected, clip, clip_dim, terminated, trials_remain,
active, object_, object_sel, object_dim, object_pos, background, rotation_parity,
selection, operation
) in enumerate(zip(
b_inp, b_inp_dim, b_answer, b_grid, b_grid_dim, b_selected, b_clip, b_clip_dim, b_terminated, b_trials_remain,
b_inp, b_inp_dim, b_answer, b_answer_dim, b_grid, b_grid_dim, b_selected, b_clip, b_clip_dim, b_terminated, b_trials_remain,
b_active, b_object_, b_object_sel, b_object_dim, b_object_pos, b_background, b_rotation_parity,
b_selection, b_operation)):

(nb_grid[i], nb_grid_dim[i], nb_selected[i], nb_clip[i], nb_clip_dim[i], nb_terminated[i], nb_trials_remain[i],
nb_active[i], nb_object_[i], nb_object_sel[i], nb_object_dim[i], nb_object_pos[i],
nb_background[i], nb_rotation_parity[i]) = act(
inp, inp_dim, answer, grid, grid_dim, selected, clip, clip_dim, terminated, trials_remain,
nb_background[i], nb_rotation_parity[i], reward[i]) = act(
inp, inp_dim, answer, answer_dim, grid, grid_dim, selected, clip, clip_dim, terminated, trials_remain,
active, object_, object_sel, object_dim, object_pos, background, rotation_parity,
selection, operation)

return (
nb_grid, nb_grid_dim, nb_selected, nb_clip, nb_clip_dim, nb_terminated, nb_trials_remain,
nb_active, nb_object_, nb_object_sel, nb_object_dim, nb_object_pos,
nb_background, nb_rotation_parity
nb_background, nb_rotation_parity, reward
)
9 changes: 7 additions & 2 deletions foarcle/foo2arcenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self,
data_loader,
max_grid_size: Tuple[SupportsInt, SupportsInt],
colors: SupportsInt,
max_trial: SupportsInt=-1):
max_trial: SupportsInt=3):
self.loader = data_loader
self.H, self.W = max_grid_size
self.colors = colors
Expand Down Expand Up @@ -64,8 +64,12 @@ def get_problem(self, options={}):
input_dim = inp[subprob_index].shape
inp = np.pad(
inp[subprob_index], [(0, self.H - input_dim[0]), (0, self.W - input_dim[1])], constant_values=0).astype(np.uint8)

answer_dim = ans[subprob_index].shape
answer = np.pad(
ans[subprob_index], [(0, self.H - answer_dim[0]), (0, self.W - answer_dim[1])], constant_values=0).astype(np.uint8)

return {'input': inp, 'input_dim': input_dim, 'answer': ans[subprob_index]}
return {'input': inp, 'input_dim': input_dim, 'answer': answer, 'answer_dim': answer_dim}

def reset(self, seed=None, options={}):
super().reset(seed=seed, options=options)
Expand Down Expand Up @@ -118,6 +122,7 @@ def _step(self, state: ObsType, action: ActType, info: Dict):
info['input'],
info['input_dim'],
info['answer'],
info['answer_dim'],
state['grid'],
state['grid_dim'],
state['selected'],
Expand Down

0 comments on commit f805623

Please sign in to comment.