diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 710fdcf..2d99961 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -805,7 +805,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
@@ -814,7 +814,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def p_sample_loop(self, shape, text_cond):
         device = self.betas.device
 
@@ -842,7 +842,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         loss = self.loss_fn(pred, target)
         return loss
 
-    @torch.no_grad()
+    @torch.inference_mode()
     @eval_decorator
     def sample(self, text, num_samples_per_batch = 2):
         # in the paper, what they did was
@@ -1639,12 +1639,6 @@ class Decoder(BaseGaussianDiffusion):
 
             yield unet.cpu()
 
-
-    @torch.no_grad()
-    def get_image_embed(self, image):
-        image_embed, _ = self.clip.embed_image(image)
-        return image_embed
-
     def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
         pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
 
@@ -1659,7 +1653,7 @@ class Decoder(BaseGaussianDiffusion):
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
@@ -1668,7 +1662,7 @@ class Decoder(BaseGaussianDiffusion):
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
         device = self.betas.device
 
@@ -1712,7 +1706,7 @@ class Decoder(BaseGaussianDiffusion):
         loss = self.loss_fn(pred, target)
         return loss
 
-    @torch.no_grad()
+    @torch.inference_mode()
     @eval_decorator
     def sample(
         self,
@@ -1845,7 +1839,7 @@ class DALLE2(nn.Module):
 
         self.to_pil = T.ToPILImage()
 
-    @torch.no_grad()
+    @torch.inference_mode()
     @eval_decorator
     def forward(
         self,
diff --git a/setup.py b/setup.py
index 1410cb4..5953e7a 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
  },
-  version = '0.0.98',
+  version = '0.0.99',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
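
Note on the change (not part of the diff): swapping torch.no_grad() for torch.inference_mode() on the sampling paths keeps the no-gradient semantics but also skips autograd's view and version-counter bookkeeping, which is a small speedup for pure inference. The one trade-off is that tensors created under inference_mode() can never re-enter autograd later. A minimal standalone sketch of the difference, using only core torch:

import torch

x = torch.randn(4, 512)

with torch.no_grad():
    y = x * 2  # no history recorded, but y is an ordinary tensor

with torch.inference_mode():
    z = x * 2  # faster: no autograd bookkeeping at all

print(y.requires_grad, z.requires_grad)  # False False
y.requires_grad_(True)                   # fine: y can rejoin autograd
# z.requires_grad_(True)                 # RuntimeError: z is an inference tensor

Since p_sample, p_sample_loop, sample, and DALLE2.forward are only ever used for generation, never inside a training graph, the stricter mode is safe here.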