diff --git a/behavior/baselines/behavioral_cloning/base_input_utils.py b/behavior/baselines/behavioral_cloning/base_input_utils.py index 60fdd34..3ccf605 100644 --- a/behavior/baselines/behavioral_cloning/base_input_utils.py +++ b/behavior/baselines/behavioral_cloning/base_input_utils.py @@ -13,7 +13,7 @@ # Constants IMG_DIM = 128 ACT_DIM = 28 -PROPRIOCEPTION_DIM = 20 +PROPRIOCEPTION_DIM = 22 TASK_OBS_DIM = 456 diff --git a/behavior/baselines/behavioral_cloning/simple_bc_agent.py b/behavior/baselines/behavioral_cloning/simple_bc_agent.py index 2f40498..435fcd2 100644 --- a/behavior/baselines/behavioral_cloning/simple_bc_agent.py +++ b/behavior/baselines/behavioral_cloning/simple_bc_agent.py @@ -4,7 +4,7 @@ import argparse import logging import sys - +import os sys.path.insert(0, "../utils") import base_input_utils as BIU import torch @@ -18,7 +18,7 @@ class BCNet_rgbp(nn.Module): """A behavioral cloning agent that uses RGB images and proprioception as state space""" - def __init__(self, img_channels=3, proprioception_dim=20, num_actions=28): + def __init__(self, img_channels=3, proprioception_dim=22, num_actions=28): super(BCNet_rgbp, self).__init__() # image feature self.features1 = nn.Sequential( @@ -46,7 +46,7 @@ def forward(self, imgs, proprioceptions): class BCNet_taskObs(nn.Module): - def __init__(self, task_obs_dim=456, proprioception_dim=20, num_actions=28): + def __init__(self, task_obs_dim=456, proprioception_dim=22, num_actions=28): super(BCNet_taskObs, self).__init__() # image feature self.fc1 = nn.Linear(task_obs_dim + proprioception_dim, 1024) @@ -76,16 +76,17 @@ def parse_args(): # Training device = "cuda" if torch.cuda.is_available() else "cpu" - # bc_agent = BCNet_rgbp().to(device) - bc_agent = BCNet_taskObs().to(device) + bc_agent = BCNet_rgbp().to(device) + # bc_agent = BCNet_taskObs().to(device) optimizer = optim.Adam(bc_agent.parameters()) NUM_EPOCH = 5 - PATH = "trained_models/model.pth" + PATH = "./trained_models" for epoch in range(NUM_EPOCH): optimizer.zero_grad() - output = bc_agent(data.task_obss, data.proprioceptions) + # output = bc_agent(data.task_obss, data.proprioceptions) + output = bc_agent(data.rgbs, data.proprioceptions) loss_func = nn.MSELoss() loss = loss_func(output, data.actions) loss.backward() @@ -93,4 +94,6 @@ def parse_args(): log.info(loss.item()) - torch.save(bc_agent, PATH) + if not os.path.exists(PATH): + os.makedirs(PATH) + torch.save(bc_agent, (f"{PATH}/model.pth"))