Parameter freezing in the learn() function of TD3 #13

Open
Cassini-Titan opened this issue Dec 11, 2023 · 0 comments
Hi, I'd like to ask about the delayed policy update in TD3, which involves freezing parameters. It seems that removing the freeze and unfreeze operations would not affect the code, because there is no update to the critic network in between.
Code:

# Trick 3: delayed policy updates
if self.actor_pointer % self.policy_freq == 0:
    # Freeze critic networks so you don't waste computational effort
    # ---- freeze section ---------------------------------------------------
    for params in self.critic.parameters():
        params.requires_grad = False  # Removing the freeze/unfreeze parts seems to make no difference?
    # -------------------------------------------------------------------------

    # Compute actor loss
    actor_loss = -self.critic.Q1(batch_s, self.actor(batch_s)).mean()  # Only use Q1
    # Optimize the actor
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()

    # Unfreeze critic networks
    # ---- unfreeze section -------------------------------------------------
    for params in self.critic.parameters():
        params.requires_grad = True
    # -------------------------------------------------------------------------

    # Softly update the target networks
    for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
        target_param.data.copy_(self.TAU * param.data + (1 - self.TAU) * target_param.data)

    for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
        target_param.data.copy_(self.TAU * param.data + (1 - self.TAU) * target_param.data)