Skip to content

Commit

Permalink
Fix log_alpha in modeling_sac: change it to an nn.Parameter
Browse files Browse the repository at this point in the history
added pretrained vision model in policy

Co-authored-by: Adil Zouitine <[email protected]>
  • Loading branch information
michel-aractingi and AdilZouitine committed Feb 13, 2025
1 parent dc086dc commit 459f22e
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 8 deletions.
4 changes: 2 additions & 2 deletions lerobot/common/policies/sac/modeling_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(
# TODO (azouitine): Handle the case where the temperature parameter is fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
# it triggers "can't optimize a non-leaf Tensor"
self.log_alpha = torch.tensor([0.0], requires_grad=True, device=torch.device("mps"))
self.log_alpha = nn.Parameter(torch.tensor([0.0]))
self.temperature = self.log_alpha.exp().item()

def reset(self):
Expand Down Expand Up @@ -634,7 +634,7 @@ def _load_pretrained_vision_encoder(self, config):
"""Set up CNN encoder"""
from transformers import AutoModel

self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name)
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
# self.image_enc_layers.pooler = Identity()

if hasattr(self.image_enc_layers.config, "hidden_sizes"):
Expand Down
4 changes: 2 additions & 2 deletions lerobot/configs/env/so100_real.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# @package _global_

fps: 30
fps: 10

env:
name: real_world
Expand All @@ -26,6 +26,6 @@ env:
joint_masking_action_space: [1, 1, 1, 1, 0, 0] # disable wrist and gripper

reward_classifier:
pretrained_path: outputs/classifier/checkpoints/best/pretrained_model
pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
config_path: lerobot/configs/policy/hilserl_classifier.yaml

4 changes: 2 additions & 2 deletions lerobot/configs/policy/sac_real.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# env.gym.obs_type=environment_state_agent_pos \

seed: 1
dataset_repo_id: null # aractingi/push_green_cube_hf_cropped_resized
dataset_repo_id: aractingi/push_cube_square_light_offline_demo_cropped_resized

training:
# Offline training dataloader
Expand Down Expand Up @@ -52,7 +52,7 @@ policy:
n_action_steps: 1

shared_encoder: true
# vision_encoder_name: null
vision_encoder_name: "helper2424/resnet10"
freeze_vision_encoder: true
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
Expand Down
3 changes: 1 addition & 2 deletions lerobot/scripts/server/learner_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def add_actor_information_and_train(
next_observations = batch["next_state"]
done = batch["done"]

assert_and_breakpoint(observations=observations, actions=actions, next_state=next_observations)
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

with policy_lock:
loss_critic = policy.compute_loss_critic(
Expand Down Expand Up @@ -533,7 +533,6 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
lr_scheduler = None
optimizers = {
Expand Down

0 comments on commit 459f22e

Please sign in to comment.