diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index b3793a4..e5da5e9 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -483,7 +483,7 @@ class DiffusionPrior(nn.Module): timesteps = 1000, cond_drop_prob = 0.2, loss_type = "l1", - predict_x0 = True, + predict_x_start = True, beta_schedule = "cosine", ): super().__init__() @@ -497,7 +497,7 @@ class DiffusionPrior(nn.Module): self.image_size = clip.image_size self.cond_drop_prob = cond_drop_prob - self.predict_x0 = predict_x0 + self.predict_x_start = predict_x_start # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. if beta_schedule == "cosine": @@ -586,14 +586,14 @@ class DiffusionPrior(nn.Module): def p_mean_variance(self, x, t, text_cond, clip_denoised: bool): pred = self.net(x, t, **text_cond) - if self.predict_x0: + if self.predict_x_start: x_recon = pred # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken else: x_recon = self.predict_start_from_noise(x, t = t, noise = pred) - if clip_denoised and not self.predict_x0: + if clip_denoised and not self.predict_x_start: x_recon.clamp_(-1., 1.) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) @@ -639,7 +639,7 @@ class DiffusionPrior(nn.Module): **text_cond ) - to_predict = noise if not self.predict_x0 else image_embed + to_predict = noise if not self.predict_x_start else image_embed if self.loss_type == 'l1': loss = F.l1_loss(to_predict, x_recon) @@ -1121,8 +1121,8 @@ class Decoder(nn.Module): cond_drop_prob = 0.2, loss_type = 'l1', beta_schedule = 'cosine', - predict_x0 = False, - predict_x0_for_latent_diffusion = False, + predict_x_start = False, + predict_x_start_for_latent_diffusion = False, image_sizes = None, # for cascading ddpm, image size at each stage lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur @@ -1173,7 +1173,7 @@ class Decoder(nn.Module): # predict x0 config - self.predict_x0 = cast_tuple(predict_x0, len(unets)) if not predict_x0_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes)) + self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes)) # cascading ddpm related stuff @@ -1293,31 +1293,31 @@ class Decoder(nn.Module): posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped - def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x0 = False, cond_scale = 1.): + def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.): pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) - if predict_x0: + if predict_x_start: x_recon = pred else: x_recon = self.predict_start_from_noise(x, t = t, noise = pred) - if clip_denoised and not predict_x0: + if clip_denoised and not predict_x_start: x_recon.clamp_(-1., 1.) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() - def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x0 = False, clip_denoised = True, repeat_noise = False): + def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False): b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x0 = predict_x0) + model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start) noise = noise_like(x.shape, device, repeat_noise) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() - def p_sample_loop(self, unet, shape, image_embed, predict_x0 = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1): + def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1): device = self.betas.device b = shape[0] @@ -1332,7 +1332,7 @@ class Decoder(nn.Module): text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, - predict_x0 = predict_x0 + predict_x_start = predict_x_start ) return img @@ -1345,7 +1345,7 @@ class Decoder(nn.Module): extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) - def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x0 = False, noise = None): + def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None): noise = default(noise, lambda: torch.randn_like(x_start)) x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise) @@ -1359,7 +1359,7 @@ class Decoder(nn.Module): cond_drop_prob = self.cond_drop_prob ) - target = noise if not predict_x0 else x_start + target = noise if not predict_x_start else x_start if self.loss_type == 'l1': loss = F.l1_loss(target, x_recon) @@ -1381,7 +1381,7 @@ class Decoder(nn.Module): img = None - for unet, vae, channel, image_size, predict_x0 in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x0)): + for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)): with self.one_unet_in_gpu(unet = unet): lowres_cond_img = None shape = (batch_size, channel, image_size, image_size) @@ -1401,7 +1401,7 @@ class Decoder(nn.Module): image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, - predict_x0 = predict_x0, + predict_x_start = predict_x_start, lowres_cond_img = lowres_cond_img ) @@ -1425,7 +1425,7 @@ class Decoder(nn.Module): target_image_size = self.image_sizes[unet_index] vae = self.vaes[unet_index] - predict_x0 = self.predict_x0[unet_index] + predict_x_start = self.predict_x_start[unet_index] b, c, h, w, device, = *image.shape, image.device @@ -1449,7 +1449,7 @@ class Decoder(nn.Module): if exists(lowres_cond_img): lowres_cond_img = vae.encode(lowres_cond_img) - return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x0 = predict_x0) + return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start) # main class diff --git a/setup.py b/setup.py index c00a455..520aecc 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.42', + version = '0.0.43', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',