Compare commits

...

2 Commits

3 changed files with 27 additions and 18 deletions

dalle2_pytorch/dalle2_pytorch.py

@@ -61,6 +61,9 @@ def default(val, d):
 def cast_tuple(val, length = 1):
     return val if isinstance(val, tuple) else ((val,) * length)
 
+def module_device(module):
+    return next(module.parameters()).device
+
 @contextmanager
 def null_context(*args, **kwargs):
     yield
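Note: the new `module_device` helper simply reports the device of a module's first parameter; it is used further down to remember and restore unet placement and to simplify `DALLE2.forward`. A minimal usage sketch (the `nn.Linear` is illustrative, not from this diff):

    import torch.nn as nn

    def module_device(module):
        # device of the first parameter; assumes the module has at least one
        return next(module.parameters()).device

    print(module_device(nn.Linear(4, 4)))  # device(type='cpu') unless moved
    # caveat: a parameterless module such as nn.ReLU() would raise StopIteration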
@@ -936,7 +939,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
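Note: this hunk, and every matching hunk below (including the trainer file), swaps `torch.inference_mode()` back to `torch.no_grad()` on the sampling entrypoints. A plausible motivation, inferred rather than stated in the diff: tensors produced under `inference_mode` are permanently marked as inference tensors and cannot be used in autograd-tracked computation later, while `no_grad` merely disables gradient tracking. A minimal sketch of the failure mode, assuming a recent PyTorch:

    import torch

    net = torch.nn.Linear(4, 4)

    with torch.inference_mode():
        sampled = net(torch.randn(2, 4))        # an inference tensor

    try:
        # reusing it in tracked computation fails with
        # "Inference tensors cannot be saved for backward ..."
        (net(sampled) ** 2).mean().backward()
    except RuntimeError as err:
        print(err)

    with torch.no_grad():
        sampled = net(torch.randn(2, 4))        # ordinary tensor, requires_grad=False

    (net(sampled) ** 2).mean().backward()       # works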
@@ -945,7 +948,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
         device = self.betas.device
@@ -981,7 +984,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         loss = self.loss_fn(pred, target)
         return loss
 
-    @torch.inference_mode()
+    @torch.no_grad()
     @eval_decorator
     def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):
         device = self.betas.device
@@ -993,7 +996,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
             img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
         return img
 
-    @torch.inference_mode()
+    @torch.no_grad()
     @eval_decorator
     def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
         # in the paper, what they did was
@@ -1816,11 +1819,15 @@ class Decoder(BaseGaussianDiffusion):
         unet = self.get_unet(unet_number)
 
-        self.cuda()
-        self.unets.cpu()
+        devices = [module_device(unet) for unet in self.unets]
+        self.unets.cpu()
         unet.cuda()
 
         yield
 
-        unet.cpu()
+        for unet, device in zip(self.unets, devices):
+            unet.to(device)
 
     def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
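Note: this rewrites the context manager that stages one unet at a time on the GPU (named `one_unet_in_gpu` in this repo, an assumption since the def line falls outside the hunk). The old version moved the whole decoder to CUDA, pushed every unet to the CPU, and after the `yield` sent the active unet back to the CPU regardless of where it started; the new version records each unet's device with `module_device` and restores all of them afterward. Hypothetical usage:

    # run only the second unet on the GPU while sampling from it;
    # afterward every unet returns to whatever device it started on
    with decoder.one_unet_in_gpu(unet_number = 2):
        ...  # sample with that unet here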
@@ -1853,7 +1860,7 @@ class Decoder(BaseGaussianDiffusion):
         return model_mean, posterior_variance, posterior_log_variance
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
@@ -1862,7 +1869,7 @@ class Decoder(BaseGaussianDiffusion):
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
         device = self.betas.device
@@ -1955,12 +1962,14 @@ class Decoder(BaseGaussianDiffusion):
         return loss + vb_loss
 
-    @torch.inference_mode()
+    @torch.no_grad()
     @eval_decorator
     def sample(
         self,
         image_embed = None,
         text = None,
+        text_mask = None,
         text_encodings = None,
         batch_size = 1,
         cond_scale = 1.,
         stop_at_unet_number = None
@@ -1970,8 +1979,8 @@ class Decoder(BaseGaussianDiffusion):
         if not self.unconditional:
             batch_size = image_embed.shape[0]
 
         text_encodings = text_mask = None
-        if exists(text):
+        if exists(text) and not exists(text_encodings) and not self.unconditional:
             assert exists(self.clip)
             _, text_encodings, text_mask = self.clip.embed_text(text)
 
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
@@ -2023,6 +2032,7 @@ class Decoder(BaseGaussianDiffusion):
         text = None,
         image_embed = None,
         text_encodings = None,
+        text_mask = None,
         unet_number = None
     ):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
@@ -2047,7 +2057,6 @@ class Decoder(BaseGaussianDiffusion):
             assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
             image_embed, _ = self.clip.embed_image(image)
 
-        text_encodings = text_mask = None
         if exists(text) and not exists(text_encodings) and not self.unconditional:
             assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
             _, text_encodings, text_mask = self.clip.embed_text(text)
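Note: taken together, the last few hunks thread `text_mask` through `Decoder.sample` and `Decoder.forward` and stop resetting `text_encodings` / `text_mask` inside `forward`, so precomputed CLIP text conditioning can be supplied instead of raw text. A hedged sketch of what this enables (argument names are from the diff; `clip`, `images`, `image_embed`, and tokenized `text` are assumed to exist):

    # precompute the text conditioning once, reuse it for training and sampling
    _, text_encodings, text_mask = clip.embed_text(text)

    loss = decoder(images, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, unet_number = 1)
    samples = decoder.sample(image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask)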
@@ -2094,7 +2103,7 @@ class DALLE2(nn.Module):
         self.to_pil = T.ToPILImage()
 
-    @torch.inference_mode()
+    @torch.no_grad()
     @eval_decorator
     def forward(
         self,
@@ -2103,7 +2112,7 @@ class DALLE2(nn.Module):
         prior_cond_scale = 1.,
         return_pil_images = False
     ):
-        device = next(self.parameters()).device
+        device = module_device(self)
 
         one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
 
         if isinstance(text, str) or is_list_str(text):

dalle2_pytorch/trainer.py

@@ -278,17 +278,17 @@ class DiffusionPriorTrainer(nn.Module):
         self.step += 1
 
-    @torch.inference_mode()
+    @torch.no_grad()
     @cast_torch_tensor
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
 
-    @torch.inference_mode()
+    @torch.no_grad()
     @cast_torch_tensor
     def sample(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def sample_batch_size(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.39',
+  version = '0.2.41',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',