Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weights of Restored model after checkpointing do not give same loss as model before saving #1402

Open
delara38 opened this issue Nov 26, 2024 · 1 comment

Comments

@delara38
Copy link

delara38 commented Nov 26, 2024

Hello,

I have a class that contains an nnx.Module and trains it. I try to save and restore by accessing this attribute but as the title says I find that when I restore the model, it's loss is as bad as a randomly initialized model.

I have no way to describe the problem as anything or than the title says because I will train a model, halve the loss from it's initialization, save the model using the instructions in the tutorial on saving and loading models (or the instructions given here google/flax#4383, or the instructions on the orbax website) and then restore them in another file and re-run the training loop. However at the final step my loss is the same as the loss I got at initialization. Note, that the parameters are not the ones I had at initialization but completely different ones that are equally poor when evaluated on my objective function.

I have attatched the code for my model, my training file, and my loading function.

Model file:

@nnx.jit
def training_step(model, optimizer, key, X, Y, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, maxT):
    batch_size = X.shape[0] 
    t_key, key = random.split(key)
    t = random.randint(t_key, (batch_size//2+1), 0, maxT)

    t = jnp.concatenate([t, maxT-t-1], axis=-1)[:batch_size]

    ab_s = jnp.take(alphas_bar_sqrt, t)[..., None]
    am1 = jnp.take(one_minus_alphas_bar_sqrt, t)[..., None]

    noise_key, key = random.split(key)
    noise = random.normal(noise_key, X.shape)

    x_ts = ab_s*X + am1*noise


    
    def conditional_model_estimation_loss(model):

        preds = model(x_ts, Y, t[..., None])

        loss = jnp.mean(jnp.square(preds-noise))
        return loss 

    loss, grads = nnx.value_and_grad(conditional_model_estimation_loss)(model)

    optimizer.update(grads)
    return loss

def make_beta_schedule(schedule = 'linear', T=100, start=1e-5, end = 0.5e-2):
    if schedule == 'linear':
        betas = jnp.linspace(start, end, T)
    elif schedule == 'cosine':
        fn = lambda x: jnp.cos(x/T + np.pi/2).pow(2)
        betas = fn(jnp.linspace(0,1, T))
    elif schedule == "sigmoid":
        betas = jnp.linspace(-6, 6, T)
        
        betas = jax.nn.sigmoid(betas) * (end - start) + start
    
    else:
        raise ValueError("schedule not implemented yet")
    return betas

class ConditionalLinear(nnx.Module):
    def __init__(self, num_in, num_out, n_steps, rngs):
        self.lin = nnx.Linear(num_in, num_out, rngs = rngs)
        self.embed = nnx.Embed(n_steps, num_out, rngs = rngs)

    def __call__(self, x, t):
        xout = self.lin(x)
        em = jnp.reshape(self.embed(t), xout.shape)
        #em = jnp.squeeze(self.embed(t), -2)
        out = xout*em
        #print(xout.shape, out.shape, em.shape, x.shape)
        return out 


class ConditionalDiffusionModel(nnx.Module):
    def __init__(self, dim, conditioning_dim, hidden_dim, T, rngs):
        self.dim = dim 

        self.cond_emb1 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs)
        self.cond_emb2 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs)
        self.cond_emb3 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs) 
        
        self.l1 = ConditionalLinear(dim+hidden_dim, hidden_dim, T, rngs)
        self.l2 = ConditionalLinear(hidden_dim, hidden_dim, T, rngs)
        self.l3 = ConditionalLinear(hidden_dim, hidden_dim, T, rngs)
        self.last_layer = nnx.Linear(hidden_dim, dim, rngs=rngs)
        



    def __call__(self, x, y, ts):
        yemb = nnx.softplus(self.cond_emb1(y, ts))
        yemb = nnx.softplus(self.cond_emb2(y, ts))
        yemb = nnx.softplus(self.cond_emb3(y, ts))

        xus = jnp.concatenate([x, yemb], axis=-1)
        xus = nnx.softplus(self.l1(xus, ts))
        xus = nnx.softplus(self.l2(xus, ts))
        xus = nnx.softplus(self.l3(xus, ts))

        preds = self.last_layer(xus)

        return preds
    
class ConditionalDiffuser:
    def __init__(self, dim, conditioning_dim, hidden_dim, beta_schedule, T, sigma, rngs):
        self.model = ConditionalDiffusionModel(dim, conditioning_dim, hidden_dim, T, rngs)
        self.betas = jnp.array(make_beta_schedule(beta_schedule, T))
        self.alphas = 1-self.betas 
        self.alpha_bars = jnp.cumprod(self.alphas)
        self.alphas_bar_sqrt = jnp.sqrt(self.alpha_bars)
        self.one_minus_alphas_bar_sqrt = jnp.sqrt(1-self.alpha_bars)
        self.alpha_bars_p = jnp.concatenate([jnp.array([1]), self.alpha_bars[:-1]])
        self.T = T
        self.sigma = sigma 

    def __call__(self, x):
        return self.model(x)

    def get_model(self):
        return self.model
    def train(self, key, dataset, opt, batch_size, epochs):
        epoch_loss = [0]
        ema_loss = None 
        ema = 0.9

        for e in range(epochs):
            if e % 10 == 0:
                print(f"Epoch: {e}\\t: epoch loss {epoch_loss[-1]}, ema loss {ema_loss}")
            
            el = []
            key, perm_key = random.split(key)
            permutation = random.permutation(perm_key, dataset['samples'].shape[0])

            with tqdm(range(0, dataset['samples'].shape[0], batch_size)) as tp:
                for i in tp:

                    indices = permutation[i:i+batch_size]
                    X = dataset['samples'][indices]
                    Y = dataset['conditioners'][indices]

                    step_key, key = random.split(key)
                    loss = training_step(self.model, opt, step_key, X, Y, self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt, self.T)

                    el.append(loss)
                    if ema_loss is None:
                        ema_loss = loss 
                    else:
                        ema_loss = ema*ema_loss + (1-ema)*loss 
                    
                    tp.set_postfix(loss = np.mean(el[-40:]))
            el = np.mean(el)
            epoch_loss.append(el)
    

    def backward(self, x, y, ts, key):


        if type(self.sigma) == list:
            sigmas = np.choose(self.sigma, ts)
        else:
            sigmas = self.sigma
    
        
        a_roots = jnp.sqrt(jnp.take(self.alphas,(ts-1).astype(jnp.int32))[..., None])
        betas = jnp.take(self.betas, ts.astype(jnp.int32))[..., None]
        one_minus_abar_roots = jnp.sqrt(1 - jnp.take(self.alpha_bars, ts.astype(jnp.int32))[..., None])

        sigma_ts = jnp.sqrt( betas )


        pred_noise = self.model(x, y, ts.astype(jnp.int32))

        n_key, key = random.split(key)
        x_t = (1/a_roots) * (x - ( betas / one_minus_abar_roots )*pred_noise) + sigma_ts * random.normal(n_key, x.shape)
        


        return x_t
    
    
    def complete_backward(self, x,y, T, key):
        for t in range(1,T):
            
            #xus = torch.concatenate([x, torch.ones(x.shape[0], 1)*(T-t)], dim=-1)
            diff_key, key = random.split(key)
            ts = (jnp.ones((x.shape[0], ))*(T-t))
            x = self.backward(x,y, ts, diff_key )
        return x

Training file

def main(args):


    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    
    env = gym.make(args.env)
    action_dim = env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]


    dataset = np.load(args.dataset_path)

    obs = dataset['observations']
    actions = dataset['actions']
    rew_to_go = dataset['reward_to_go']
    print(rew_to_go.shape)
    Y = jnp.array(np.concatenate([obs, rew_to_go], axis=-1))

    X = jnp.array(actions)
    data = {'samples': X, 'conditioners': Y}

    if args.restore:
            print("Loading from Checkpoint")
            ckpt_path = Path('saved_models/my_checkpoints/').resolve()
            diff_model = load_model(ckpt_path)
            model = ConditionalDiffuser(dim=action_dim, conditioning_dim=state_dim+1, hidden_dim= 300, beta_schedule='linear', T=252, sigma=0.1, rngs=nnx.Rngs(0, noise=1))
            model.model = diff_model

    else:
            print("Building Fresh Model")
            model = ConditionalDiffuser(dim=action_dim, conditioning_dim=state_dim+1, hidden_dim= 300, beta_schedule='linear', T=252, sigma=0.1, rngs=nnx.Rngs(0, noise=1))
        #model = ConditionalDiffusionModel(action_dim, state_dim+1, 1000, 8, 'linear', 52, 0.05).to(device)
        #model.load_state_dict(torch.load(args.model_out))

    opt = nnx.Optimizer(model.model, optax.adam(args.lr))

    Y = jnp.array(np.concatenate([obs, rew_to_go], axis=-1))

    X = jnp.array(actions)
    data = {'samples': X, 'conditioners': Y}

     train_key = jax.random.key(1)
     model.train(train_key, data, opt, args.batch_size, args.epochs)
        
        
        # Prepare state to save
     _,  state = nnx.split(model.model)

     print("Checkpointing...")
        # Save using checkpoint manager
        
        # Checkpointing
     ckpt_path = Path('saved_models/my_checkpoints/').resolve()
        #ckpt_path.mkdir(parents=True, exist_ok=True)  # Ensure directory exists
        
     checkpointer = ocp.StandardCheckpointer()
     checkpointer.save(ckpt_path/'attempt_8', state)
        

     sus_model = load_model(ckpt_path)
     _, restored_state = nnx.split(sus_model)
     assert(jax.tree.map(np.testing.assert_array_equal, restored_state, state))
        #print(other_state)

     print("Done checkpointing")
    return

Load function

def load_model(checkpoint_dir, env_name='Pendulum-v1'):
    # Initialize environment to get dimensions
    env = gym.make(env_name)
    action_dim = env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]
    
    # Initialize model with same parameters as during training
    model = ConditionalDiffuser(
        dim=action_dim,
        conditioning_dim=state_dim+1,
        hidden_dim=1000,
        beta_schedule='linear',
        T=52,
        sigma=0.05,
        rngs=nnx.Rngs(0, noise=1)
    )
    
    abstract_model = nnx.eval_shape(lambda: ConditionalDiffuser(dim=action_dim, conditioning_dim=state_dim+1, hidden_dim= 300, beta_schedule='linear', T=252, sigma=0.1, rngs=nnx.Rngs(0, noise=1)))
    graphdef, abstract_state = nnx.split(abstract_model)

    
    checkpointer = ocp.StandardCheckpointer()
    loaded_state = checkpointer.restore(checkpoint_dir/'attempt_8', abstract_state)
    
    #jax.tree.map(np.testing.assert_array_equal, abstract_state, loaded_state)
    #print(loaded_state)
    #print(nnx.display(loaded_state))
    model = nnx.merge(graphdef, loaded_state)


    return model
@niketkumar
Copy link
Collaborator

Following debugging steps can help.

  • Is the saved state different from loaded_state? They should match. Or they should differ only as per abstract_state.
  • Is abstract_state expected and suitable to restore what was saved earlier? It should at least structurally be the same as state.
checkpointer.save(ckpt_path/'attempt_8', state)
 ...
graphdef, abstract_state = nnx.split(abstract_model)
...
loaded_state = checkpointer.restore(checkpoint_dir/'attempt_8', abstract_state)

If loaded_state is expected for the given state, then it is not a checkpointing issue. In that case, it needs debugging of nnx or modules other than orbax.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants