Skip to content

Commit

Permalink
minor refactoring to avoid pylance error
Browse files Browse the repository at this point in the history
  • Loading branch information
hyeok9855 committed Sep 7, 2023
1 parent 4bc6661 commit 00aa8a9
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions hw1/cs285/policies/MLP_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self,
n_layers=self.n_layers, size=self.size,
)
self.mean_net.to(ptu.device)
self.logstd = nn.Parameter(
self.logstd = nn.parameter.Parameter(
torch.zeros(self.ac_dim, dtype=torch.float32, device=ptu.device)
)
self.logstd.to(ptu.device)
Expand Down Expand Up @@ -108,11 +108,11 @@ def __init__(self, ac_dim, ob_dim, n_layers, size, **kwargs):
def update(self, observations: np.ndarray, actions: np.ndarray, **kwargs) -> Dict:
assert len(observations) == len(actions), "Check the lengths"

observations = ptu.from_numpy(observations)
actions = ptu.from_numpy(actions)
observations_t = ptu.from_numpy(observations)
actions_t = ptu.from_numpy(actions)

distn = self(observations)
loss = -distn.log_prob(actions).mean()
distn = self(observations_t)
loss = -distn.log_prob(actions_t).mean()

self.optimizer.zero_grad()
loss.backward()
Expand All @@ -121,9 +121,11 @@ def update(self, observations: np.ndarray, actions: np.ndarray, **kwargs) -> Dic

def forward(self, observation: torch.FloatTensor) -> Any:
    """Compute the action distribution for a batch of observations.

    Args:
        observation: Tensor of observations fed to the policy network.

    Returns:
        A ``distributions.Categorical`` over discrete actions when
        ``self.discrete`` is set, otherwise a ``distributions.Normal``
        whose mean comes from ``self.mean_net`` and whose (state-
        independent) std is ``exp(self.logstd)`` broadcast over the batch.
    """
    # Guard clause: discrete policies are fully described by their logits.
    if self.discrete:
        assert self.logits_na is not None
        return distributions.Categorical(logits=self.logits_na(observation))

    # Continuous case: Gaussian with a learned, observation-independent std.
    assert self.mean_net is not None and self.logstd is not None
    mean = self.mean_net(observation)
    std = self.logstd[None].exp()  # add leading dim so std broadcasts over batch
    return distributions.Normal(loc=mean, scale=std)

0 comments on commit 00aa8a9

Please sign in to comment.