Skip to content

Commit

Permalink
Add a margin at the top of the window
Browse files Browse the repository at this point in the history
  • Loading branch information
tomekster committed Jan 27, 2024
1 parent 654faf8 commit 151f8f4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 18 deletions.
Binary file added mo_gymnasium/envs/fruit_tree/assets/node_blue.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
61 changes: 43 additions & 18 deletions mo_gymnasium/envs/fruit_tree/fruit_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ class FruitTreeEnv(gym.Env, EzPickle):
The episode terminates when the agent reaches a leaf node.
"""

metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

def __init__(self, depth=6, render_mode: Optional[str] = None):
assert depth in [5, 6, 7], "Depth must be 5, 6 or 7."
EzPickle.__init__(self, depth)
Expand All @@ -296,6 +298,8 @@ def __init__(self, depth=6, render_mode: Optional[str] = None):

# pygame
self.row_height = 20
self.top_margin = 15

# Add margin at the bottom to account for the node rewards
self.window_size = (1200, self.row_height * self.tree_depth + 150)
self.node_square_size = np.array([10, 10], dtype=np.int32)
Expand All @@ -305,6 +309,7 @@ def __init__(self, depth=6, render_mode: Optional[str] = None):
self.font = pygame.font.SysFont(None, self.font_size)

self.window = None
self.clock = None
self.node_img = None
self.agent_img = None

Expand Down Expand Up @@ -375,68 +380,88 @@ def render(self):
)
return

if self.clock is None and self.render_mode == "human":
self.clock = pygame.time.Clock()

if self.window is None:
pygame.init()

if self.render_mode == "human":
pygame.display.init()
pygame.display.set_caption("Fruit Tree")
self.window = pygame.display.set_mode(self.window_size)
self.clock.tick(self.metadata["render_fps"])
else:
self.window = pygame.Surface(self.window_size)

if self.node_img is None:
filename = path.join(path.dirname(__file__), "assets", "node.png")
filename = path.join(path.dirname(__file__), "assets", "node_blue.png")
self.node_img = pygame.transform.scale(pygame.image.load(filename), self.node_square_size)
self.node_img = pygame.transform.flip(self.node_img, flip_x=True, flip_y=False)

if self.agent_img is None:
filename = path.join(path.dirname(__file__), "assets", "agent.png")
self.agent_img = pygame.transform.scale(pygame.image.load(filename), self.node_square_size)

canvas = pygame.Surface(self.window_size)
canvas.fill((0, 0, 0))
canvas.fill((255, 255, 255)) # White

self.window.blit(canvas, (0, 0))
# draw branches
for ind, node in enumerate(self.tree):
row, index_in_row = self.ind_to_state(ind)
node_pos = self.get_pos_in_window(row, index_in_row)
if row < self.tree_depth:
# Get childerns' positions and draw branches
child1_pos = self.get_pos_in_window(row + 1, 2 * index_in_row)
child2_pos = self.get_pos_in_window(row + 1, 2 * index_in_row + 1)
half_square = self.node_square_size / 2
pygame.draw.line(canvas, (90, 82, 85), node_pos + half_square, child1_pos + half_square, 1)
pygame.draw.line(canvas, (90, 82, 85), node_pos + half_square, child2_pos + half_square, 1)

for ind, node in enumerate(self.tree):
row, index_in_row = self.ind_to_state(ind)
if (row, index_in_row) == tuple(self.current_state):
img = self.agent_img
font_color = (255, 0, 0) # Red digits for agent node
font_color = (164, 0, 0) # Red digits for agent node
else:
img = self.node_img
font_color = (0, 255, 0) # Green digits for non-agent nodes
if ind % 2:
font_color = (250, 128, 114) # Green
else:
font_color = (45, 72, 101) # Dark Blue

node_pos = self.get_pos_in_window(row, index_in_row)

self.window.blit(img, np.array(node_pos))
canvas.blit(img, np.array(node_pos))

if row < self.tree_depth:
# Get childerns' positions and draw branches
child1_pos = self.get_pos_in_window(row + 1, 2 * index_in_row)
child2_pos = self.get_pos_in_window(row + 1, 2 * index_in_row + 1)
half_square = self.node_square_size / 2
pygame.draw.line(self.window, (255, 255, 255), node_pos + half_square, child1_pos + half_square, 1)
pygame.draw.line(self.window, (255, 255, 255), node_pos + half_square, child2_pos + half_square, 1)
else:
# Print node values at the bottom of the tree
# Print node values at the bottom of the tree
if row == self.tree_depth:
odd_nodes_values_offset = 0.5 * (ind % 2)
values_imgs = [self.font.render(f"{val:.2f}", True, font_color) for val in node]
for i, val_img in enumerate(values_imgs):
self.window.blit(val_img, node_pos + np.array([-5, (i + 1) * self.font_size]))
canvas.blit(val_img, node_pos + np.array([-5, (i + 1 + odd_nodes_values_offset) * 1.5 * self.font_size]))

if self.render_mode == "human":
pygame.event.pump()
pygame.display.update()
self.clock.tick(self.metadata["render_fps"])
elif self.render_mode == "rgb_array":
return np.transpose(np.array(pygame.surfarray.pixels3d(self.window)), axes=(1, 0, 2))
return np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))

background = pygame.Surface(self.window_size)
background.fill((255, 255, 255)) # White

background.blit(canvas, (0, self.top_margin))

self.window.blit(background, (0, 0))


if __name__ == "__main__":
import time

import mo_gymnasium as mo_gym

env = mo_gym.make("fruit-tree", depth=6, render_mode="human")
env = mo_gym.make("fruit-tree", depth=7, render_mode="human")
env.reset()
while True:
env.render()
Expand Down

0 comments on commit 151f8f4

Please sign in to comment.