diff --git a/.gitmodules b/.gitmodules index 5256486..e69de29 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "CLIP"] - path = CLIP - url = https://github.com/openai/CLIP diff --git a/CLIP b/CLIP deleted file mode 160000 index 40f5484..0000000 --- a/CLIP +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 40f5484c1c74edd83cb9cf687c6ab92b28d8b656 diff --git a/diffusion/sampling.py b/diffusion/sampling.py old mode 100644 new mode 100755 index da4c36a..5bf05d5 --- a/diffusion/sampling.py +++ b/diffusion/sampling.py @@ -3,9 +3,79 @@ from . import utils +# These 4 sample_foo functions are subroutines called by sample() +def sample_step_pred(model, x, steps, eta, extra_args, ts, alphas, sigmas, i): + # Get the model output (v, the predicted velocity) + with torch.cuda.amp.autocast(): + v = model(x, ts * steps[i], **extra_args).float() -# DDPM/DDIM sampling + # Predict the noise and the denoised image + pred = x * alphas[i] - v * sigmas[i] + + return pred, v + + +def sample_step_noise(model, x, steps, eta, extra_args, ts, alphas, sigmas, i, pred, v): + eps = x * sigmas[i] + v * alphas[i] + + # If we are not on the last timestep, compute the noisy image for the + # next timestep. + if i < len(steps) - 1: + # If eta > 0, adjust the scaling factor for the predicted noise + # downward according to the amount of additional noise to add + ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ + (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() + adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() + + # Recombine the predicted noise and predicted denoised image in the + # correct proportions for the next step + x = pred * alphas[i + 1] + eps * adjusted_sigma + # Add the correct amount of fresh noise + + if eta: + x = x + torch.randn_like(x) * ddim_sigma + + return x + +def sample_setup(model, x, steps, eta, extra_args): + """Draws samples from a model given starting noise.""" + + # print("SAMPLE SETUP ", steps.shape) + ts = x.new_ones([x.shape[0]]) + + # Create the noise schedule + alphas, sigmas = utils.t_to_alpha_sigma(steps) + + sample_state = [model, steps, eta, extra_args, ts, alphas, sigmas] + return sample_state + +def sample_step(sample_state, x, i, last_pred, last_v): + model, steps, eta, extra_args, ts, alphas, sigmas = sample_state + pred, v = sample_step_pred(model, x, steps, eta, extra_args, ts, alphas, sigmas, i) + return pred, v, x + + +def sample_noise(sample_state, x, i, last_pred, last_v): + model, steps, eta, extra_args, ts, alphas, sigmas = sample_state + if last_pred != None: + x = sample_step_noise(model, x, steps, eta, extra_args, ts, alphas, sigmas, i, last_pred, last_v) + return x + +# this new version of sample calls the above four functions to do the work +def sample_split(model, x, steps, eta, extra_args): + pred = None + v = None + sample_state = sample_setup(model, x, steps, eta, extra_args) + for i in trange(len(steps)): + pred, v, x = sample_step(sample_state, x, i, pred, v) + x = sample_noise(sample_state, x, i, pred, v) + + return pred + +# this is the original version of sample which did everything at once + +# DDPM/DDIM sampling @torch.no_grad() def sample(model, x, steps, eta, extra_args, callback=None): """Draws samples from a model given starting noise."""