Mirror of https://github.com/lucidrains/DALLE2-pytorch.git
Synced 2026-02-12 11:34:29 +01:00
Compare commits

3 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 34806663e3 |  |
|  | dc816b1b6e |  |
|  | 05192ffac4 |  |
dalle2_pytorch/dalle2_pytorch.py

```diff
@@ -38,6 +38,8 @@ from coca_pytorch import CoCa

 NAT = 1. / math.log(2.)

+UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
+
 # helper functions

 def exists(val):
```
```diff
@@ -1004,9 +1006,9 @@ class DiffusionPriorNetwork(nn.Module):

         # setup self conditioning

         self_cond = None
         if self.self_cond:
-            self_cond = default(self_cond, lambda: torch.zeros(batch, 1, self.dim, device = device, dtype = dtype))
+            self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
+            self_cond = rearrange(self_cond, 'b d -> b 1 d')

         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
```
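For orientation: the old line allocated the null self-conditioning token directly as `(batch, 1, dim)`, while the new code allocates `(batch, dim)` and then rearranges, so a real self-conditioning embedding of shape `(batch, dim)` now flows through the same reshape as the zero placeholder. A minimal standalone sketch of the shape equivalence (names and sizes are illustrative, not repo code):

```python
# The two allocations yield the same (b, 1, d) token; the new form also
# normalizes a real self_cond of shape (b, d) through the identical path.
import torch
from einops import rearrange

batch, dim = 4, 512

old_null = torch.zeros(batch, 1, dim)                          # old: (b, 1, d) directly
new_null = rearrange(torch.zeros(batch, dim), 'b d -> b 1 d')  # new: (b, d) -> (b, 1, d)
assert old_null.shape == new_null.shape == (batch, 1, dim)

real_self_cond = torch.randn(batch, dim)              # e.g. a previous embedding prediction
as_token = rearrange(real_self_cond, 'b d -> b 1 d')  # same rearrange as the null token
assert as_token.shape == (batch, 1, dim)
```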
```diff
@@ -1277,9 +1279,12 @@ class DiffusionPrior(nn.Module):
         is_ddim = timesteps < self.noise_scheduler.num_timesteps

         if not is_ddim:
-            return self.p_sample_loop_ddpm(*args, **kwargs)
-
-        return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+            normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
+        else:
+            normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+
+        image_embed = normalized_image_embed / self.image_embed_scale
+        return image_embed

     def p_losses(self, image_embed, times, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))
```
```diff
@@ -1348,8 +1353,6 @@ class DiffusionPrior(nn.Module):

-        # retrieve original unscaled image embed
-
-        image_embeds /= self.image_embed_scale

         text_embeds = text_cond['text_embed']

         text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
```
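Read together, these two hunks move the unscaling of the sampled embedding out of the caller and into `p_sample_loop` itself, so every sampling path divides by `image_embed_scale` exactly once. A hedged sketch of the resulting control flow (abridged; the `default(...)` line for `timesteps` is an assumption, not shown in the diff):

```python
# Sketch of the refactored entry point: both branches produce the normalized
# embedding, and the unscale happens once at the end instead of at call sites.
def p_sample_loop(self, *args, timesteps = None, **kwargs):
    timesteps = default(timesteps, self.noise_scheduler.num_timesteps)  # assumed
    is_ddim = timesteps < self.noise_scheduler.num_timesteps

    if not is_ddim:
        normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
    else:
        normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)

    # replaces the per-caller `image_embeds /= self.image_embed_scale`
    return normalized_image_embed / self.image_embed_scale
```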
```diff
@@ -2584,6 +2587,14 @@ class Decoder(nn.Module):
         index = unet_number - 1
         return self.unets[index]

+    def parse_unet_output(self, learned_variance, output):
+        var_interp_frac_unnormalized = None
+
+        if learned_variance:
+            output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
+
+        return UnetOutput(output, var_interp_frac_unnormalized)
+
     @contextmanager
     def one_unet_in_gpu(self, unet_number = None, unet = None):
         assert exists(unet_number) ^ exists(unet)
```
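This helper, paired with the `UnetOutput` namedtuple added at the top of the file, centralizes the repeated "split off the learned-variance channels" logic; every call site below becomes a single tuple unpack. A self-contained sketch of the pattern (tensor shapes invented for illustration):

```python
from collections import namedtuple

import torch

UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])

def parse_unet_output(learned_variance, output):
    # with learned variance the unet predicts 2x channels: the first half is the
    # noise / x0 prediction, the second half the variance interpolation fraction
    var_interp_frac_unnormalized = None

    if learned_variance:
        output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)

    return UnetOutput(output, var_interp_frac_unnormalized)

# learned variance: a (b, 2c, h, w) output splits into two (b, c, h, w) halves
pred, var = parse_unet_output(True, torch.randn(2, 6, 8, 8))
assert pred.shape == var.shape == (2, 3, 8, 8)

# without it the namedtuple still unpacks uniformly; the variance slot is None
pred, var = parse_unet_output(False, torch.randn(2, 3, 8, 8))
assert var is None
```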
```diff
@@ -2625,10 +2636,9 @@ class Decoder(nn.Module):
     def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

-        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
+        model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))

-        if learned_variance:
-            pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
+        pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)

         if predict_x_start:
             x_start = pred
```
```diff
@@ -2811,10 +2821,9 @@ class Decoder(nn.Module):

             self_cond = x_start if unet.self_cond else None

-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+            unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)

-            if learned_variance:
-                pred, _ = pred.chunk(2, dim = 1)
+            pred, _ = self.parse_unet_output(learned_variance, unet_output)

             if predict_x_start:
                 x_start = pred
```
```diff
@@ -2886,16 +2895,13 @@ class Decoder(nn.Module):

         if unet.self_cond and random.random() < 0.5:
             with torch.no_grad():
-                self_cond = unet(x_noisy, times, **unet_kwargs)
-
-                if learned_variance:
-                    self_cond, _ = self_cond.chunk(2, dim = 1)
-
+                unet_output = unet(x_noisy, times, **unet_kwargs)
+                self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
                 self_cond = self_cond.detach()

         # forward to get model prediction

-        model_output = unet(
+        unet_output = unet(
             x_noisy,
             times,
             **unet_kwargs,
```
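For context, this hunk touches the standard self-conditioning training recipe (popularized by Chen et al.'s "Analog Bits"): on a random half of the steps, run the unet once without gradients, keep only the detached prediction, and feed it back as extra conditioning for the real forward pass. A hedged standalone sketch of that recipe, reusing the `parse_unet_output` helper sketched above (the wrapper function and its signature are hypothetical, not repo code):

```python
import random

import torch

def self_conditioned_forward(unet, x_noisy, times, learned_variance, **unet_kwargs):
    # hypothetical wrapper illustrating the 50% self-conditioning recipe
    self_cond = None

    if getattr(unet, 'self_cond', False) and random.random() < 0.5:
        with torch.no_grad():
            unet_output = unet(x_noisy, times, **unet_kwargs)
            self_cond, _ = parse_unet_output(learned_variance, unet_output)
            self_cond = self_cond.detach()  # no gradient through the first pass

    return unet(x_noisy, times, self_cond = self_cond, **unet_kwargs)
```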
```diff
@@ -2904,10 +2910,7 @@ class Decoder(nn.Module):
             text_cond_drop_prob = self.text_cond_drop_prob,
         )

-        if learned_variance:
-            pred, _ = model_output.chunk(2, dim = 1)
-        else:
-            pred = model_output
+        pred, _ = self.parse_unet_output(learned_variance, unet_output)

         target = noise if not predict_x_start else x_start
```
```diff
@@ -2930,7 +2933,7 @@ class Decoder(nn.Module):
             # if learning the variance, also include the extra weight kl loss

             true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
-            model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+            model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)

             # kl loss with detached model predicted mean, for stability reasons as in paper
```
dalle2_pytorch/version.py

```diff
@@ -1 +1 @@
-__version__ = '1.6.1'
+__version__ = '1.6.5'
```