From 6afb886cf493daf714d6200aa76ed5e603268664 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 18 Jul 2022 13:43:57 -0700
Subject: [PATCH] complete imagen-like noise level conditioning

---
 dalle2_pytorch/dalle2_pytorch.py | 179 ++++++++++++++++++++++++-------
 dalle2_pytorch/version.py        |   2 +-
 2 files changed, 144 insertions(+), 37 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index ff8fc77..d0c5b51 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -52,10 +52,10 @@ def first(arr, d = None):
 
 def maybe(fn):
     @wraps(fn)
-    def inner(x):
+    def inner(x, *args, **kwargs):
         if not exists(x):
             return x
-        return fn(x)
+        return fn(x, *args, **kwargs)
     return inner
 
 def default(val, d):
@@ -63,13 +63,13 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
-def cast_tuple(val, length = None):
+def cast_tuple(val, length = None, validate = True):
     if isinstance(val, list):
         val = tuple(val)
 
     out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
 
-    if exists(length):
+    if exists(length) and validate:
         assert len(out) == length
 
     return out
@@ -494,6 +494,9 @@ class NoiseScheduler(nn.Module):
         self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
         register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
 
+    def sample_random_times(self, batch):
+        return torch.randint(0, self.num_timesteps, (batch,), device = self.betas.device, dtype = torch.long)
+
     def q_posterior(self, x_start, x_t, t):
         posterior_mean = (
             extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
@@ -1243,7 +1246,7 @@ class DiffusionPrior(nn.Module):
 
         # timestep conditioning from ddpm
 
         batch, device = image_embed.shape[0], image_embed.device
-        times = torch.randint(0, self.noise_scheduler.num_timesteps, (batch,), device = device, dtype = torch.long)
+        times = self.noise_scheduler.sample_random_times(batch)
 
         # scale image embed (Katherine)
@@ -1540,6 +1543,7 @@ class Unet(nn.Module):
         attn_dim_head = 32,
         attn_heads = 16,
         lowres_cond = False,        # for cascading diffusion - https://cascaded-diffusion.github.io/
+        lowres_noise_cond = False,  # for conditioning on low resolution noising, based on Imagen
         sparse_attn = False,
         attend_at_middle = True,    # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
         cond_on_text_encodings = False,
@@ -1629,6 +1633,17 @@ class Unet(nn.Module):
             self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
             self.text_embed_dim = text_embed_dim
 
+        # low resolution noise conditioning, based on Imagen's upsampler training technique
+
+        self.lowres_noise_cond = lowres_noise_cond
+
+        self.to_lowres_noise_cond = nn.Sequential(
+            SinusoidalPosEmb(dim),
+            nn.Linear(dim, time_cond_dim),
+            nn.GELU(),
+            nn.Linear(time_cond_dim, time_cond_dim)
+        ) if lowres_noise_cond else None
+
         # finer control over whether to condition on image embeddings and text encodings
         # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1745,15 +1760,17 @@ class Unet(nn.Module):
         self,
         *,
         lowres_cond,
+        lowres_noise_cond,
         channels,
         channels_out,
         cond_on_image_embeds,
-        cond_on_text_encodings
+        cond_on_text_encodings,
     ):
         if lowres_cond == self.lowres_cond and \
             channels == self.channels and \
             cond_on_image_embeds == self.cond_on_image_embeds and \
             cond_on_text_encodings == self.cond_on_text_encodings and \
+            lowres_noise_cond == self.lowres_noise_cond and \
             channels_out == self.channels_out:
             return self
 
@@ -1762,7 +1779,8 @@ class Unet(nn.Module):
             channels = channels,
             channels_out = channels_out,
             cond_on_image_embeds = cond_on_image_embeds,
-            cond_on_text_encodings = cond_on_text_encodings
+            cond_on_text_encodings = cond_on_text_encodings,
+            lowres_noise_cond = lowres_noise_cond
         )
 
         return self.__class__(**{**self._locals, **updated_kwargs})
@@ -1788,6 +1806,7 @@ class Unet(nn.Module):
         *,
         image_embed,
         lowres_cond_img = None,
+        lowres_noise_level = None,
         text_encodings = None,
         image_cond_drop_prob = 0.,
         text_cond_drop_prob = 0.,
@@ -1816,6 +1835,13 @@ class Unet(nn.Module):
         time_tokens = self.to_time_tokens(time_hiddens)
         t = self.to_time_cond(time_hiddens)
 
+        # low res noise conditioning (similar to time above)
+
+        if exists(lowres_noise_level):
+            assert exists(self.to_lowres_noise_cond), 'lowres_noise_cond must be set to True on instantiation of the unet in order to condition on lowres noise'
+            lowres_noise_level = lowres_noise_level.type_as(x)
+            t = t + self.to_lowres_noise_cond(lowres_noise_level)
+
         # conditional dropout
 
         image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
@@ -1965,25 +1991,48 @@ class LowresConditioner(nn.Module):
     def __init__(
         self,
         downsample_first = True,
+        use_blur = True,
         blur_prob = 0.5,
        blur_sigma = 0.6,
         blur_kernel_size = 3,
-        input_image_range = None
+        use_noise = False,
+        input_image_range = None,
+        normalize_img_fn = identity,
+        unnormalize_img_fn = identity
     ):
         super().__init__()
         self.downsample_first = downsample_first
         self.input_image_range = input_image_range
 
+        self.use_blur = use_blur
         self.blur_prob = blur_prob
         self.blur_sigma = blur_sigma
         self.blur_kernel_size = blur_kernel_size
 
+        self.use_noise = use_noise
+        self.normalize_img = normalize_img_fn
+        self.unnormalize_img = unnormalize_img_fn
+        self.noise_scheduler = NoiseScheduler(beta_schedule = 'linear', timesteps = 1000, loss_type = 'l2') if use_noise else None
+
+    def noise_image(self, cond_fmap, noise_levels = None):
+        assert exists(self.noise_scheduler)
+
+        batch = cond_fmap.shape[0]
+        cond_fmap = self.normalize_img(cond_fmap)
+
+        random_noise_levels = default(noise_levels, lambda: self.noise_scheduler.sample_random_times(batch))
+        cond_fmap = self.noise_scheduler.q_sample(cond_fmap, t = random_noise_levels, noise = torch.randn_like(cond_fmap))
+
+        cond_fmap = self.unnormalize_img(cond_fmap)
+        return cond_fmap, random_noise_levels
+
     def forward(
         self,
         cond_fmap,
         *,
         target_image_size,
         downsample_image_size = None,
+        should_blur = True,
         blur_sigma = None,
         blur_kernel_size = None
     ):
@@ -1993,7 +2042,7 @@ class LowresConditioner(nn.Module):
         # blur is only applied 50% of the time
         # section 3.1 in https://arxiv.org/abs/2106.15282
 
-        if random.random() < self.blur_prob:
+        if self.use_blur and should_blur and random.random() < self.blur_prob:
 
             # when training, blur the low resolution conditional image
 
@@ -2015,8 +2064,21 @@ class LowresConditioner(nn.Module):
 
             cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
 
+        # resize to target image size
+
         cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range, nearest = True)
-        return cond_fmap
+
+        # noise conditioning, as done in Imagen
+        # as a replacement for the BSR noising, and potentially replace blurring for first stage too
+
+        random_noise_levels = None
+
+        if self.use_noise:
+            cond_fmap, random_noise_levels = self.noise_image(cond_fmap)
+
+        # return conditioning feature map, as well as the augmentation noise levels
+
+        return cond_fmap, random_noise_levels
 
 class Decoder(nn.Module):
     def __init__(
@@ -2037,10 +2099,13 @@ class Decoder(nn.Module):
         predict_x_start_for_latent_diffusion = False,
         image_sizes = None,                 # for cascading ddpm, image size at each stage
         random_crop_sizes = None,           # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
+        use_noise_for_lowres_cond = False,  # whether to use Imagen-like noising for low resolution conditioning
+        use_blur_for_lowres_cond = True,    # whether to use the blur conditioning used in the original cascading ddpm paper, as well as DALL-E2
         lowres_downsample_first = True,     # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
         blur_prob = 0.5,                    # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
         blur_sigma = 0.6,                   # cascading ddpm - blur sigma
         blur_kernel_size = 3,               # cascading ddpm - blur kernel size
+        lowres_noise_sample_level = 0.2,    # in imagen paper, they use a 0.2 noise level at sample time for low resolution conditioning
         clip_denoised = True,
         clip_x_start = True,
         clip_adapter_overrides = dict(),
@@ -2088,10 +2153,17 @@ class Decoder(nn.Module):
 
         self.channels = channels
 
+        # normalize and unnormalize image functions
+
+        self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
+        self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
+
         # verify conditioning method
 
         unets = cast_tuple(unet)
         num_unets = len(unets)
+        self.num_unets = num_unets
 
         self.unconditional = unconditional
@@ -2107,12 +2179,28 @@ class Decoder(nn.Module):
         self.learned_variance_constrain_frac = learned_variance_constrain_frac  # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
         self.vb_loss_weight = vb_loss_weight
 
+        # default and validate conditioning parameters
+
+        use_noise_for_lowres_cond = cast_tuple(use_noise_for_lowres_cond, num_unets - 1, validate = False)
+        use_blur_for_lowres_cond = cast_tuple(use_blur_for_lowres_cond, num_unets - 1, validate = False)
+
+        if len(use_noise_for_lowres_cond) < num_unets:
+            use_noise_for_lowres_cond = (False, *use_noise_for_lowres_cond)
+
+        if len(use_blur_for_lowres_cond) < num_unets:
+            use_blur_for_lowres_cond = (False, *use_blur_for_lowres_cond)
+
+        assert not use_noise_for_lowres_cond[0], 'first unet will never need low res noise conditioning'
+        assert not use_blur_for_lowres_cond[0], 'first unet will never need low res blur conditioning'
+
+        assert num_unets == 1 or all((use_noise or use_blur) for use_noise, use_blur in zip(use_noise_for_lowres_cond[1:], use_blur_for_lowres_cond[1:]))
+
         # construct unets and vaes
 
         self.unets = nn.ModuleList([])
         self.vaes = nn.ModuleList([])
 
-        for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
+        for ind, (one_unet, one_vae, one_unet_learned_var, lowres_noise_cond) in enumerate(zip(unets, vaes, learned_variance, use_noise_for_lowres_cond)):
             assert isinstance(one_unet, Unet)
             assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
@@ -2124,6 +2212,7 @@ class Decoder(nn.Module):
             one_unet = one_unet.cast_model_parameters(
                 lowres_cond = not is_first,
+                lowres_noise_cond = lowres_noise_cond,
                 cond_on_image_embeds = not unconditional and is_first,
                 cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings,
                 channels = unet_channels,
@@ -2166,7 +2255,7 @@ class Decoder(nn.Module):
 
         image_sizes = default(image_sizes, (image_size,))
         image_sizes = tuple(sorted(set(image_sizes)))
 
-        assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
+        assert self.num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({self.num_unets}) for resolutions {image_sizes}'
 
         self.image_sizes = image_sizes
         self.sample_channels = cast_tuple(self.channels, len(image_sizes))
@@ -2186,15 +2275,30 @@ class Decoder(nn.Module):
         # cascading ddpm related stuff
 
         lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
-        assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
+        assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
 
-        self.to_lowres_cond = LowresConditioner(
-            downsample_first = lowres_downsample_first,
-            blur_prob = blur_prob,
-            blur_sigma = blur_sigma,
-            blur_kernel_size = blur_kernel_size,
-            input_image_range = self.input_image_range
-        )
+        self.lowres_conds = nn.ModuleList([])
+
+        for unet_index, use_noise, use_blur in zip(range(num_unets), use_noise_for_lowres_cond, use_blur_for_lowres_cond):
+            if unet_index == 0:
+                self.lowres_conds.append(None)
+                continue
+
+            lowres_cond = LowresConditioner(
+                downsample_first = lowres_downsample_first,
+                use_blur = use_blur,
+                use_noise = use_noise,
+                blur_prob = blur_prob,
+                blur_sigma = blur_sigma,
+                blur_kernel_size = blur_kernel_size,
+                input_image_range = self.input_image_range,
+                normalize_img_fn = self.normalize_img,
+                unnormalize_img_fn = self.unnormalize_img
+            )
+
+            self.lowres_conds.append(lowres_cond)
+
+        self.lowres_noise_sample_level = lowres_noise_sample_level
 
         # classifier free guidance
@@ -2212,11 +2316,6 @@ class Decoder(nn.Module):
         self.use_dynamic_thres = use_dynamic_thres
         self.dynamic_thres_percentile = dynamic_thres_percentile
 
-        # normalize and unnormalize image functions
-
-        self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
-        self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
-
         # device tracker
 
         self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)
@@ -2230,7 +2329,7 @@ class Decoder(nn.Module):
         return any([unet.cond_on_text_encodings for unet in self.unets])
 
     def get_unet(self, unet_number):
-        assert 0 < unet_number <= len(self.unets)
+        assert 0 < unet_number <= self.num_unets
         index = unet_number - 1
         return self.unets[index]
@@ -2316,7 +2415,7 @@ class Decoder(nn.Module):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
-    def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
+    def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
         device = self.device
 
         b = shape[0]
@@ -2334,6 +2433,7 @@ class Decoder(nn.Module):
                 text_encodings = text_encodings,
                 cond_scale = cond_scale,
                 lowres_cond_img = lowres_cond_img,
+                lowres_noise_level = lowres_noise_level,
                 predict_x_start = predict_x_start,
                 noise_scheduler = noise_scheduler,
                 learned_variance = learned_variance,
@@ -2344,7 +2444,7 @@ class Decoder(nn.Module):
         return unnormalize_img
 
     @torch.no_grad()
-    def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
+    def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
         batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
 
         times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
@@ -2363,7 +2463,7 @@ class Decoder(nn.Module):
 
             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
 
-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
 
             if learned_variance:
                 pred, _ = pred.chunk(2, dim = 1)
@@ -2402,7 +2502,7 @@ class Decoder(nn.Module):
 
         return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
 
-    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
+    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         # normalize to [-1, 1]
@@ -2421,6 +2521,7 @@ class Decoder(nn.Module):
             image_embed = image_embed,
             text_encodings = text_encodings,
             lowres_cond_img = lowres_cond_img,
+            lowres_noise_level = lowres_noise_level,
             image_cond_drop_prob = self.image_cond_drop_prob,
             text_cond_drop_prob = self.text_cond_drop_prob,
         )
@@ -2500,20 +2601,24 @@ class Decoder(nn.Module):
         img = None
         is_cuda = next(self.parameters()).is_cuda
 
-        num_unets = len(self.unets)
+        num_unets = self.num_unets
         cond_scale = cast_tuple(cond_scale, num_unets)
 
-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps, cond_scale)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
 
             context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
             with context:
-                lowres_cond_img = None
+                lowres_cond_img = lowres_noise_level = None
                 shape = (batch_size, channel, image_size, image_size)
 
                 if unet.lowres_cond:
                     lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = True)
 
+                    if lowres_cond.use_noise:
+                        lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)
+                        lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)
+
                 is_latent_diffusion = isinstance(vae, VQGanVAE)
                 image_size = vae.get_encoded_fmap_size(image_size)
                 shape = (batch_size, vae.encoded_dim, image_size, image_size)
@@ -2530,6 +2635,7 @@ class Decoder(nn.Module):
                     learned_variance = learned_variance,
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,
+                    lowres_noise_level = lowres_noise_level,
                     is_latent_diffusion = is_latent_diffusion,
                     noise_scheduler = noise_scheduler,
                     timesteps = sample_timesteps
@@ -2551,7 +2657,7 @@ class Decoder(nn.Module):
         unet_number = None,
         return_lowres_cond_image = False  # whether to return the low resolution conditioning images, for debugging upsampler purposes
     ):
-        assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
+        assert not (self.num_unets > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {self.num_unets}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
         unet_index = unet_number - 1
@@ -2559,6 +2665,7 @@ class Decoder(nn.Module):
 
         vae = self.vaes[unet_index]
         noise_scheduler = self.noise_schedulers[unet_index]
+        lowres_conditioner = self.lowres_conds[unet_index]
         target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
@@ -2581,7 +2688,7 @@ class Decoder(nn.Module):
             assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
             assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
 
-        lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
+        lowres_cond_img, lowres_noise_level = lowres_conditioner(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if exists(lowres_conditioner) else (None, None)
         image = resize_image_to(image, target_image_size, nearest = True)
 
         if exists(random_crop_size):
@@ -2599,7 +2706,7 @@ class Decoder(nn.Module):
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
 
-        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
+        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
 
         if not return_lowres_cond_image:
             return losses
 
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index c6d6a56..8c308d7 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.24.3'
+__version__ = '0.25.0'
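
Notes on the patch follow.

The heart of the change is LowresConditioner.noise_image: during upsampler training, the low resolution conditioning image is corrupted with the same q_sample forward-diffusion step used on the target image, and the sampled timestep is handed to the unet as lowres_noise_level, where to_lowres_noise_cond embeds it exactly like the timestep embedding. Below is a minimal, self-contained sketch of that augmentation, assuming a linear beta schedule; the helper names are illustrative only, since the patch itself delegates this step to NoiseScheduler.q_sample.

    import torch

    # forward diffusion q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
    # which is what NoiseScheduler.q_sample computes internally

    def linear_beta_schedule(timesteps):
        # same schedule family the patch requests ('linear', 1000 timesteps)
        return torch.linspace(1e-4, 0.02, timesteps)

    timesteps = 1000
    alphas_cumprod = torch.cumprod(1. - linear_beta_schedule(timesteps), dim = 0)

    def noise_image(cond_fmap, noise_levels):
        # corrupt the lowres conditioning image up to the given per-sample noise levels
        noise = torch.randn_like(cond_fmap)
        alpha_bar = alphas_cumprod[noise_levels].view(-1, 1, 1, 1)
        return alpha_bar.sqrt() * cond_fmap + (1. - alpha_bar).sqrt() * noise

    lowres = torch.rand(4, 3, 64, 64) * 2 - 1         # images already normalized to [-1, 1]
    noise_levels = torch.randint(0, timesteps, (4,))  # random times, as during training
    noised = noise_image(lowres, noise_levels)

    # `noise_levels` is then fed to the unet as `lowres_noise_level`, telling the upsampler
    # how corrupted its conditioning image is; at sample time the patch instead uses the
    # fixed level int(lowres_noise_sample_level * 1000)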
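At the Decoder level, the new flags select blur and/or noise conditioning per upsampler unet; a scalar is broadcast across the upsampler unets and left-padded with False for the first unet, which never takes low resolution conditioning. Below is a hypothetical two-unet wiring: the three lowres-conditioning arguments come from this patch, while the remaining constructor arguments follow the repo's README examples and may differ slightly at this exact version.

    import torch
    from dalle2_pytorch import Unet, Decoder, OpenAIClipAdapter

    clip = OpenAIClipAdapter()  # assumes the `clip` package and its weights are available

    unet1 = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))
    unet2 = Unet(dim = 16, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))

    decoder = Decoder(
        unet = (unet1, unet2),
        image_sizes = (64, 256),
        clip = clip,
        timesteps = 1000,
        use_noise_for_lowres_cond = True,   # broadcast to (False, True); unet2 gets noise conditioning
        use_blur_for_lowres_cond = False,   # noise replaces the cascading-ddpm gaussian blur here
        lowres_noise_sample_level = 0.2     # fixed noise level applied to the lowres image at sample time
    )

    images = torch.randn(4, 3, 256, 256)

    loss = decoder(images, unet_number = 1)  # each unet is trained separately
    loss = decoder(images, unet_number = 2)  # unet2 trains against a noised lowres image + its noise level

Passing per-unet tuples instead of a scalar (e.g. use_noise_for_lowres_cond = (False, True, True) for three unets) should also work, since cast_tuple is called with validate = False and the first position is only asserted to be False.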