diff --git a/examples/generative/corrdiff/generate.py b/examples/generative/corrdiff/generate.py index 412df2524b..148a01ee80 100644 --- a/examples/generative/corrdiff/generate.py +++ b/examples/generative/corrdiff/generate.py @@ -160,7 +160,7 @@ def main(cfg: DictConfig) -> None: elif cfg.sampler.type == "stochastic": sampler_fn = partial( stochastic_sampler, - img_shape=img_shape[1], + img_shape = (img_shape[1],img_shape[0]), patch_shape=patch_shape[1], boundary_pix=cfg.sampler.boundary_pix, overlap_pix=cfg.sampler.overlap_pix, diff --git a/modulus/utils/generative/stochastic_sampler.py b/modulus/utils/generative/stochastic_sampler.py index 0bf420581b..42aeb2270c 100644 --- a/modulus/utils/generative/stochastic_sampler.py +++ b/modulus/utils/generative/stochastic_sampler.py @@ -24,15 +24,15 @@ def image_batching( - input: Tensor, - img_shape_x: int, - img_shape_y: int, - patch_shape_x: int, - patch_shape_y: int, - batch_size: int, - overlap_pix: int, - boundary_pix: int, - input_interp: Optional[Tensor] = None, + input, + img_shape_y, + img_shape_x, + patch_shape_y, + patch_shape_x, + batch_size, + overlap_pix, + boundary_pix, + input_interp=None, ) -> Tensor: """ Splits a full image into a batch of patched images. @@ -82,40 +82,42 @@ def image_batching( pad_x_right = padded_shape_x - img_shape_x - boundary_pix pad_y_right = padded_shape_y - img_shape_y - boundary_pix input_padded = torch.zeros( - input.shape[0], input.shape[1], padded_shape_x, padded_shape_y - ).cuda() + input.shape[0], input.shape[1], padded_shape_y, padded_shape_x + ).to(input.device) image_padding = torch.nn.ReflectionPad2d( (boundary_pix, pad_x_right, boundary_pix, pad_y_right) - ).cuda() # (padding_left,padding_right,padding_top,padding_bottom) + ).to( + input.device + ) # (padding_left,padding_right,padding_top,padding_bottom) input_padded = image_padding(input) patch_num = patch_num_x * patch_num_y if input_interp is not None: output = torch.zeros( patch_num * batch_size, input.shape[1] + input_interp.shape[1], - patch_shape_x, patch_shape_y, - ).cuda() + patch_shape_x, + ).to(input.device) else: output = torch.zeros( - patch_num * batch_size, input.shape[1], patch_shape_x, patch_shape_y - ).cuda() + patch_num * batch_size, input.shape[1], patch_shape_y, patch_shape_x + ).to(input.device) for x_index in range(patch_num_x): for y_index in range(patch_num_y): x_start = x_index * (patch_shape_x - overlap_pix - boundary_pix) y_start = y_index * (patch_shape_y - overlap_pix - boundary_pix) if input_interp is not None: output[ - (x_index * patch_num_x + y_index) - * batch_size : (x_index * patch_num_x + y_index + 1) + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) * batch_size, ] = torch.cat( ( input_padded[ :, :, - x_start : x_start + patch_shape_x, y_start : y_start + patch_shape_y, + x_start : x_start + patch_shape_x, ], input_interp, ), @@ -123,27 +125,27 @@ def image_batching( ) else: output[ - (x_index * patch_num_x + y_index) - * batch_size : (x_index * patch_num_x + y_index + 1) + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) * batch_size, ] = input_padded[ :, :, - x_start : x_start + patch_shape_x, y_start : y_start + patch_shape_y, + x_start : x_start + patch_shape_x, ] return output def image_fuse( - input: Tensor, - img_shape_x: int, - img_shape_y: int, - patch_shape_x: int, - patch_shape_y: int, - batch_size: int, - overlap_pix: int, - boundary_pix: int, + input, + img_shape_y, + img_shape_x, + patch_shape_y, + patch_shape_x, + batch_size, + overlap_pix, + boundary_pix, ) -> Tensor: """ Reconstructs a full image from a batch of patched images. @@ -193,11 +195,13 @@ def image_fuse( pad_y_right = padded_shape_y - img_shape_y - boundary_pix residual_x = patch_shape_x - pad_x_right # residual pixels in the last patch residual_y = patch_shape_y - pad_y_right # residual pixels in the last patch - output = torch.zeros(batch_size, input.shape[1], img_shape_x, img_shape_y).cuda() - one_map = torch.ones(1, 1, input.shape[2], input.shape[3]).cuda() + output = torch.zeros( + batch_size, input.shape[1], img_shape_y, img_shape_x, device=input.device + ) + one_map = torch.ones(1, 1, input.shape[2], input.shape[3], device=input.device) count_map = torch.zeros( - 1, 1, img_shape_x, img_shape_y - ).cuda() # to count the overlapping times + 1, 1, img_shape_y, img_shape_x, device=input.device + ) # to count the overlapping times for x_index in range(patch_num_x): for y_index in range(patch_num_y): @@ -205,105 +209,105 @@ def image_fuse( y_start = y_index * (patch_shape_y - overlap_pix - boundary_pix) if (x_index == patch_num_x - 1) and (y_index != patch_num_y - 1): output[ - :, :, x_start:, y_start : y_start + patch_shape_y - 2 * boundary_pix + :, :, y_start : y_start + patch_shape_y - 2 * boundary_pix, x_start: ] += input[ - (x_index * patch_num_x + y_index) - * batch_size : (x_index * patch_num_x + y_index + 1) + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) * batch_size, :, - boundary_pix : residual_x + boundary_pix, boundary_pix : patch_shape_y - boundary_pix, + boundary_pix : residual_x + boundary_pix, ] count_map[ - :, :, x_start:, y_start : y_start + patch_shape_y - 2 * boundary_pix + :, :, y_start : y_start + patch_shape_y - 2 * boundary_pix, x_start: ] += one_map[ :, :, - boundary_pix : residual_x + boundary_pix, boundary_pix : patch_shape_y - boundary_pix, + boundary_pix : residual_x + boundary_pix, ] elif (y_index == patch_num_y - 1) and ((x_index != patch_num_x - 1)): output[ - :, :, x_start : x_start + patch_shape_x - 2 * boundary_pix, y_start: + :, :, y_start:, x_start : x_start + patch_shape_x - 2 * boundary_pix ] += input[ - (x_index * patch_num_x + y_index) - * batch_size : (x_index * patch_num_x + y_index + 1) + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) * batch_size, :, - boundary_pix : patch_shape_x - boundary_pix, boundary_pix : residual_y + boundary_pix, + boundary_pix : patch_shape_x - boundary_pix, ] count_map[ - :, :, x_start : x_start + patch_shape_x - 2 * boundary_pix, y_start: + :, :, y_start:, x_start : x_start + patch_shape_x - 2 * boundary_pix ] += one_map[ :, :, - boundary_pix : patch_shape_x - boundary_pix, boundary_pix : residual_y + boundary_pix, + boundary_pix : patch_shape_x - boundary_pix, ] elif x_index == patch_num_x - 1 and y_index == patch_num_y - 1: - output[:, :, x_start:, y_start:] += input[ - (x_index * patch_num_x + y_index) - * batch_size : (x_index * patch_num_x + y_index + 1) + output[:, :, y_start:, x_start:] += input[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) * batch_size, :, - boundary_pix : residual_x + boundary_pix, boundary_pix : residual_y + boundary_pix, + boundary_pix : residual_x + boundary_pix, ] - count_map[:, :, x_start:, y_start:] += one_map[ + count_map[:, :, y_start:, x_start:] += one_map[ :, :, - boundary_pix : residual_x + boundary_pix, boundary_pix : residual_y + boundary_pix, + boundary_pix : residual_x + boundary_pix, ] else: output[ :, :, - x_start : x_start + patch_shape_x - 2 * boundary_pix, y_start : y_start + patch_shape_y - 2 * boundary_pix, + x_start : x_start + patch_shape_x - 2 * boundary_pix, ] += input[ - (x_index * patch_num_x + y_index) - * batch_size : (x_index * patch_num_x + y_index + 1) + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) * batch_size, :, - boundary_pix : patch_shape_x - boundary_pix, boundary_pix : patch_shape_y - boundary_pix, + boundary_pix : patch_shape_x - boundary_pix, ] count_map[ :, :, - x_start : x_start + patch_shape_x - 2 * boundary_pix, y_start : y_start + patch_shape_y - 2 * boundary_pix, + x_start : x_start + patch_shape_x - 2 * boundary_pix, ] += one_map[ :, :, - boundary_pix : patch_shape_x - boundary_pix, boundary_pix : patch_shape_y - boundary_pix, + boundary_pix : patch_shape_x - boundary_pix, ] return output / count_map def stochastic_sampler( - net: Any, - latents: Tensor, - img_lr: Tensor, - class_labels: Optional[Tensor] = None, - randn_like: Callable[[Tensor], Tensor] = torch.randn_like, - img_shape: int = 448, - patch_shape: int = 448, - overlap_pix: int = 4, - boundary_pix: int = 2, - mean_hr: Optional[Tensor] = None, - num_steps: int = 18, - sigma_min: float = 0.002, - sigma_max: float = 800, - rho: float = 7, - S_churn: float = 0, - S_min: float = 0, - S_max: float = float("inf"), - S_noise: float = 1, -) -> Tensor: + net, + latents, + img_lr, + class_labels=None, + randn_like=torch.randn_like, + img_shape=448, + patch_shape=448, + overlap_pix=4, + boundary_pix=2, + mean_hr=None, + num_steps=18, + sigma_min=0.002, + sigma_max=800, + rho=7, + S_churn=0, + S_min=0, + S_max=float("inf"), + S_noise=1, +): """ Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution and patch-based diffusion. @@ -352,10 +356,15 @@ def stochastic_sampler( Tensor The final denoised image produced by the sampler. """ - + # Adjust noise levels based on what's supported by the network. sigma_min = max(sigma_min, net.sigma_min) - sigma_max = min(sigma_max, net.sigma_max) + # sigma_max = min(sigma_max, net.sigma_max) + + if isinstance(img_shape, tuple): + img_shape_x, img_shape_y = img_shape + else: + img_shape_x = img_shape_y = img_shape # Time step discretization. step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) @@ -370,11 +379,11 @@ def stochastic_sampler( ) # t_N = 0 b = latents.shape[0] - Nx = torch.arange(img_shape) - Ny = torch.arange(img_shape) - grid = torch.stack(torch.meshgrid(Nx, Ny, indexing="ij"), dim=0)[ - None, - ].expand(b, -1, -1, -1) + Nx = torch.arange(img_shape_x) + Ny = torch.arange(img_shape_y) + grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[None,].expand( + b, -1, -1, -1 + ) # conditioning = [mean_hr, img_lr, global_lr, pos_embd] batch_size = img_lr.shape[0] @@ -384,14 +393,14 @@ def stochastic_sampler( global_index = None # input and position padding + patching - if patch_shape != img_shape: + if patch_shape != img_shape_x or patch_shape != img_shape_y: input_interp = torch.nn.functional.interpolate( img_lr, (patch_shape, patch_shape), mode="bilinear" ) x_lr = image_batching( x_lr, - img_shape, - img_shape, + img_shape_y, + img_shape_x, patch_shape, patch_shape, batch_size, @@ -401,8 +410,8 @@ def stochastic_sampler( ) global_index = image_batching( grid.float(), - img_shape, - img_shape, + img_shape_y, + img_shape_x, patch_shape, patch_shape, batch_size, @@ -414,19 +423,19 @@ def stochastic_sampler( x_next = latents.to(torch.float64) * t_steps[0] for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 x_cur = x_next + # Increase noise temporarily. - gamma = ( - min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 - ) + gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0 t_hat = net.round_sigma(t_cur + gamma * t_cur) + x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) # Euler step. Perform patching operation on score tensor if patch-based generation is used - if patch_shape != img_shape: + if patch_shape != img_shape_x or patch_shape != img_shape_y: x_hat_batch = image_batching( x_hat, - img_shape, - img_shape, + img_shape_y, + img_shape_x, patch_shape, patch_shape, batch_size, @@ -435,14 +444,19 @@ def stochastic_sampler( ) else: x_hat_batch = x_hat + + x_hat_batch = x_hat_batch.to(latents.device) + x_lr = x_lr.to(latents.device) + global_index = global_index.to(latents.device) denoised = net( x_hat_batch, x_lr, t_hat, class_labels, global_index=global_index ).to(torch.float64) - if patch_shape != img_shape: + + if patch_shape != img_shape_x or patch_shape != img_shape_y: denoised = image_fuse( denoised, - img_shape, - img_shape, + img_shape_y, + img_shape_x, patch_shape, patch_shape, batch_size, @@ -454,11 +468,11 @@ def stochastic_sampler( # Apply 2nd order correction. if i < num_steps - 1: - if patch_shape != img_shape: + if patch_shape != img_shape_x or patch_shape != img_shape_y: x_next_batch = image_batching( x_next, - img_shape, - img_shape, + img_shape_y, + img_shape_x, patch_shape, patch_shape, batch_size, @@ -467,12 +481,15 @@ def stochastic_sampler( ) else: x_next_batch = x_next - denoised = net(x_next_batch, x_lr, t_next, class_labels).to(torch.float64) - if patch_shape != img_shape: + x_next_batch = x_next_batch.to(latents.device) + denoised = net( + x_next_batch, x_lr, t_next, class_labels, global_index=global_index + ).to(torch.float64) + if patch_shape != img_shape_x or patch_shape != img_shape_y: denoised = image_fuse( denoised, - img_shape, - img_shape, + img_shape_y, + img_shape_x, patch_shape, patch_shape, batch_size,