-
Notifications
You must be signed in to change notification settings - Fork 1
/
play.py
98 lines (71 loc) · 2.26 KB
/
play.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gymnasium as gym
import pygame
import torch
import envs
import argparse
import matplotlib.pyplot as plt
import numpy as np
from model import DQN
def human_play():
    """Play one episode of ``Env-v0`` yourself, using the keyboard.

    Controls: Up arrow or Space -> ``envs.Action.JUMP``, Down arrow ->
    ``envs.Action.DUCK``, otherwise ``envs.Action.STAND``. When the episode
    ends, prints the total reward and the number of steps taken, closes the
    environment and shows the final observation with matplotlib.
    """
    env = gym.make("Env-v0", render_mode="human")
    try:
        obs, _ = env.reset()
        total_reward = 0.0
        n_frames = 0
        while True:
            n_frames += 1
            # Poll the current keyboard state.
            # NOTE(review): get_pressed() only reflects new input if pygame
            # events are pumped; this assumes the human-mode renderer does
            # that during env.step — confirm.
            pressed = pygame.key.get_pressed()
            action = envs.Action.STAND
            if pressed[pygame.K_UP] or pressed[pygame.K_SPACE]:
                action = envs.Action.JUMP
            elif pressed[pygame.K_DOWN]:
                action = envs.Action.DUCK
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += float(reward)
            # Gymnasium ends an episode on either flag; ignoring `truncated`
            # would keep stepping past the end of a time-limited episode.
            if terminated or truncated:
                break
        print(f"Total reward: {total_reward}, number of frames: {n_frames}")
    finally:
        # Close the render window even if the episode is interrupted.
        env.close()
    # Show image of the last frame
    plt.imshow(obs)
    plt.show()
def play_with_model(
env: envs.Wrapper,
policy_net: DQN,
device: torch.device,
seed: int | None = None,
) -> float:
if seed is not None:
state, _ = env.reset(seed=seed)
else:
state, _ = env.reset()
state = torch.tensor(state, device=device)
total_reward = 0.0
while True:
action = policy_net(state.unsqueeze(0)).max(dim=1)[1][0]
state, reward, terminated, _, _ = env.step(action)
state = torch.tensor(state, device=device)
total_reward += float(reward)
if terminated:
break
return total_reward
def ai_play(model_path: str):
    """Load a saved policy network and let it play one rendered episode.

    Uses the MPS device when available, otherwise the CPU. Prints the total
    reward and the number of frames recorded by the ``envs.Wrapper``,
    closes the environment, then shows the last recorded frame.

    Args:
        model_path: Path to a checkpoint holding the whole model object (the
            loaded value is used directly via ``.to`` / ``.eval``), not a
            ``state_dict``.
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")  # type: ignore
    # The checkpoint is a full pickled model, so weights_only=False is needed
    # (PyTorch >= 2.6 defaults to True and would refuse it). SECURITY: this
    # unpickles arbitrary code — only load checkpoints from a trusted source.
    # map_location lets a checkpoint saved on another device load here.
    policy_net = torch.load(model_path, map_location=device, weights_only=False).to(device)
    policy_net.eval()
    env = gym.make("Env-v0", render_mode="human")
    env = envs.Wrapper(env)
    try:
        total_reward = play_with_model(env, policy_net, device)
        print(f"Total reward: {total_reward}, number of frames: {len(env.frames)}")
    finally:
        # Close the render window even if the episode is interrupted.
        env.close()
    # Show image of the last frame
    plt.imshow(env.frames[-1])
    plt.show()
if __name__ == "__main__":
    # Command-line entry point: `play.py human` or `play.py ai -m MODEL_PATH`.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "type",
        choices=["human", "ai"],
        help="play with the keyboard ('human') or let a trained model play ('ai')",
    )
    parser.add_argument(
        "-m",
        "--model_path",
        help="path to the saved model; required when type is 'ai'",
    )
    args = parser.parse_args()
    if args.type == "ai" and args.model_path is None:
        # Report a usage error now instead of letting torch.load(None) fail later.
        parser.error("--model_path is required when type is 'ai'")
    if args.type == "human":
        human_play()
    else:
        ai_play(args.model_path)