diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 4ec9710..6e4f48c 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -936,7 +936,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
@@ -945,7 +945,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
         device = self.betas.device
 
@@ -981,7 +981,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         loss = self.loss_fn(pred, target)
         return loss
 
-    @torch.inference_mode()
+    @torch.no_grad()
     @eval_decorator
     def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):
         device = self.betas.device
@@ -993,7 +993,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
             img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
         return img
 
-    @torch.inference_mode()
+    @torch.no_grad()
     @eval_decorator
     def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
         # in the paper, what they did was
@@ -1816,11 +1816,15 @@ class Decoder(BaseGaussianDiffusion):
             unet = self.get_unet(unet_number)
 
         self.cuda()
-        self.unets.cpu()
 
+        devices = [next(unet.parameters()).device for unet in self.unets]
+        self.unets.cpu()
         unet.cuda()
+
         yield
-        unet.cpu()
+
+        for unet, device in zip(self.unets, devices):
+            unet.to(device)
 
     def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -1853,7 +1857,7 @@ class Decoder(BaseGaussianDiffusion):
 
         return model_mean, posterior_variance, posterior_log_variance
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
@@ -1862,7 +1866,7 @@ class Decoder(BaseGaussianDiffusion):
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
         device = self.betas.device
 
@@ -1955,7 +1959,7 @@ class Decoder(BaseGaussianDiffusion):
 
         return loss + vb_loss
 
-    @torch.inference_mode()
+    @torch.no_grad()
     @eval_decorator
     def sample(
         self,
@@ -2094,7 +2098,7 @@ class DALLE2(nn.Module):
 
         self.to_pil = T.ToPILImage()
 
-    @torch.inference_mode()
+    @torch.no_grad()
     @eval_decorator
     def forward(
         self,
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 812fee3..e86faff 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -278,17 +278,17 @@ class DiffusionPriorTrainer(nn.Module):
 
         self.step += 1
 
-    @torch.inference_mode()
+    @torch.no_grad()
     @cast_torch_tensor
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
 
-    @torch.inference_mode()
+    @torch.no_grad()
     @cast_torch_tensor
     def sample(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def sample_batch_size(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
 
diff --git a/setup.py b/setup.py
index 8c5bfd7..805e501 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.39',
+  version = '0.2.40',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
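
Note on the decorator swap: the patch does not state its motivation, but a likely reason is that torch.inference_mode() permanently marks every tensor it produces as an inference tensor, which can never participate in autograd afterwards, whereas torch.no_grad() merely skips graph recording and its outputs remain usable as constants in later differentiable code. A minimal standalone sketch of the failure mode (illustrative, not part of the patch):

import torch

x = torch.randn(4)
w = torch.randn(4, requires_grad = True)

with torch.inference_mode():
    sample = x * 2                  # `sample` is an inference tensor

try:
    (sample * w).sum().backward()   # raises: inference tensors cannot be
except RuntimeError as e:           # saved for backward
    print('inference_mode:', e)

with torch.no_grad():
    sample = x * 2                  # ordinary tensor, just no grad history

(sample * w).sum().backward()       # works; `sample` acts as a constant
print(w.grad)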
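Note on the one_unet_in_gpu hunk: the old context manager unconditionally moved the active unet back to CPU after the yield, while the new version records where every unet lived beforehand and restores that placement on exit. The same pattern in isolation, with hypothetical names (only the pattern mirrors the patch):

from contextlib import contextmanager

import torch.nn as nn

@contextmanager
def one_module_on_gpu(modules: nn.ModuleList, index: int):
    # remember each submodule's original device before shuffling anything
    devices = [next(m.parameters()).device for m in modules]

    modules.cpu()              # evict all submodules to CPU
    modules[index].cuda()      # keep only the needed one on the GPU

    yield modules[index]

    # restore every submodule to exactly where it started, instead of
    # unconditionally sending the used one back to CPU
    for m, device in zip(modules, devices):
        m.to(device)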