mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
back to no_grad for now, also keep track and restore unet devices in one_unet_in_gpu contextmanager
This commit is contained in:
@@ -936,7 +936,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||
@@ -945,7 +945,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
|
||||
device = self.betas.device
|
||||
|
||||
@@ -981,7 +981,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
loss = self.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):
|
||||
device = self.betas.device
|
||||
@@ -993,7 +993,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
|
||||
return img
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
|
||||
# in the paper, what they did was
|
||||
@@ -1816,11 +1816,15 @@ class Decoder(BaseGaussianDiffusion):
|
||||
unet = self.get_unet(unet_number)
|
||||
|
||||
self.cuda()
|
||||
self.unets.cpu()
|
||||
|
||||
devices = [next(unet.parameters()).device for unet in self.unets]
|
||||
self.unets.cpu()
|
||||
unet.cuda()
|
||||
|
||||
yield
|
||||
unet.cpu()
|
||||
|
||||
for unet, device in zip(self.unets, devices):
|
||||
unet.to(device)
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
@@ -1853,7 +1857,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||
@@ -1862,7 +1866,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
@@ -1955,7 +1959,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
return loss + vb_loss
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def sample(
|
||||
self,
|
||||
@@ -2094,7 +2098,7 @@ class DALLE2(nn.Module):
|
||||
|
||||
self.to_pil = T.ToPILImage()
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -278,17 +278,17 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
self.step += 1
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
def p_sample_loop(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
def sample(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
def sample_batch_size(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user