Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-22 22:14:29 +01:00)

Compare commits

2 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 9f55c24db6 | |
| | 69e822b7f8 | |
@@ -276,7 +276,7 @@ decoder = Decoder(
     cond_drop_prob = 0.2
 ).cuda()

-loss = decoder(images)
+loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
 loss.backward()

 # do above for many steps
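A minimal sketch of what the amended README line allows, assuming `decoder` is the `Decoder` instance built in the README example above; `images` and `text` are illustrative tensors whose shapes and vocabulary size are placeholders, not prescribed by the repository:

    import torch

    # `decoder` is assumed to be the Decoder constructed as in the README example above
    images = torch.randn(4, 3, 256, 256).cuda()       # a batch of training images
    text = torch.randint(0, 49408, (4, 256)).cuda()   # tokenized captions; vocab size and length are illustrative

    loss = decoder(images)        # condition on the CLIP image embedding only
    loss = decoder(images, text)  # optionally also condition on the text encodings
    loss.backward()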
@@ -319,7 +319,7 @@ Offer training wrappers
 - [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
 - [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
 - [x] make sure it works end to end to produce an output tensor, taking a single gradient step
-- [ ] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
+- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
 - [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
 - [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
 - [ ] train on a toy task, offer in colab
@@ -722,7 +722,7 @@ class Unet(nn.Module):
         dims = [channels, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))

-        # time and image embeddings
+        # time, image embeddings, and optional text encoding

         cond_dim = default(cond_dim, dim)

@@ -739,9 +739,12 @@ class Unet(nn.Module):
             Rearrange('b (n d) -> b n d', n = num_image_tokens)
         ) if image_embed_dim != cond_dim else nn.Identity()

+        self.text_to_cond = nn.LazyLinear(cond_dim)
+
         # for classifier free guidance

         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
+        self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))

         # layers

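A side note on the `nn.LazyLinear` used for `text_to_cond` above: it defers creating its weight matrix until the first forward pass, so the Unet never needs to be told the text encoder's output width. A small illustrative sketch (all dimensions are placeholders):

    import torch
    import torch.nn as nn

    cond_dim = 128
    text_to_cond = nn.LazyLinear(cond_dim)        # in_features is inferred on first use

    text_encodings = torch.randn(2, 77, 512)      # (batch, text tokens, text encoder width), illustrative
    text_tokens = text_to_cond(text_encodings)    # projected onto the Unet's conditioning dimension
    print(text_tokens.shape)                      # torch.Size([2, 77, 128])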
@@ -806,6 +809,7 @@ class Unet(nn.Module):
         time_tokens = self.time_mlp(time)

         cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
+        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')

         # mask out image embedding depending on condition dropout
         # for classifier free guidance
@@ -813,12 +817,31 @@ class Unet(nn.Module):
         image_tokens = self.image_to_cond(image_embed)

         image_tokens = torch.where(
-            rearrange(cond_prob_mask, 'b -> b 1 1'),
+            cond_prob_mask,
             image_tokens,
             self.null_image_embed
         )

-        c = torch.cat((time_tokens, image_tokens), dim = -2) # c for condition
+        # take care of text encodings (optional)
+
+        if exists(text_encodings):
+            text_tokens = self.text_to_cond(text_encodings)
+            text_tokens = torch.where(
+                cond_prob_mask,
+                text_tokens,
+                self.null_text_embed
+            )
+
+        # main conditioning tokens (c)
+
+        c = torch.cat((time_tokens, image_tokens), dim = -2)
+
+        # text and image conditioning tokens (mid_c)
+        # to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
+
+        mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
+
+        # go through the layers of the unet, down and up

         hiddens = []

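The block above is the conditioning dropout that makes classifier free guidance possible: per sample, the real conditioning tokens are either kept or swapped for learned null embeddings, so the same network also learns an unconditional prediction. A standalone sketch of the pattern follows; the helper name `keep_mask_like`, the probabilities, and the shapes are illustrative rather than the repository's exact code:

    import torch
    from einops import rearrange

    def keep_mask_like(shape, keep_prob, device):
        # True where the conditioning should be kept for that sample
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob

    batch, num_tokens, cond_dim = 4, 4, 128
    cond_drop_prob = 0.2

    image_tokens = torch.randn(batch, num_tokens, cond_dim)    # projected image embedding tokens
    null_image_embed = torch.randn(1, num_tokens, cond_dim)    # a learned nn.Parameter in the Unet, random here

    keep_mask = keep_mask_like((batch,), 1 - cond_drop_prob, image_tokens.device)
    keep_mask = rearrange(keep_mask, 'b -> b 1 1')             # broadcast over tokens and feature dim

    # keep the real conditioning tokens, or fall back to the null embedding, per sample
    image_tokens = torch.where(keep_mask, image_tokens, null_image_embed)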
@@ -828,9 +851,9 @@ class Unet(nn.Module):
             hiddens.append(x)
             x = downsample(x)

-        x = self.mid_block1(x, c)
+        x = self.mid_block1(x, mid_c)
         x = self.mid_attn(x)
-        x = self.mid_block2(x, c)
+        x = self.mid_block2(x, mid_c)

         for convnext, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
@@ -896,6 +919,10 @@ class Decoder(nn.Module):
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

+    def get_text_encodings(self, text):
+        text_encodings = self.clip.text_transformer(text)
+        return text_encodings[:, 1:]
+
     def get_image_embed(self, image):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
@@ -923,8 +950,8 @@ class Decoder(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
-        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
+    def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))

         if clip_denoised:
             x_recon.clamp_(-1., 1.)
@@ -933,31 +960,32 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance

     @torch.no_grad()
-    def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, shape, image_embed, cond_scale = 1):
+    def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
         device = self.betas.device

         b = shape[0]
         img = torch.randn(shape, device=device)

         for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
         return img

     @torch.no_grad()
-    def sample(self, image_embed, cond_scale = 1.):
+    def sample(self, image_embed, text = None, cond_scale = 1.):
         batch_size = image_embed.shape[0]
         image_size = self.image_size
         channels = self.channels
-        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
+        text_encodings = self.get_text_encodings(text) if exists(text) else None
+        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)

     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
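For context on the `cond_scale` argument threaded through `p_mean_variance`, `p_sample`, `p_sample_loop`, and `sample` above: the actual combination happens in `forward_with_cond_scale`, which is not part of this diff. Classifier free guidance is typically combined as in the sketch below, stated here only as an assumption about that helper (tensors and shapes are illustrative):

    import torch

    def guided_prediction(cond_out, null_out, cond_scale = 1.):
        # cond_scale = 1. returns the conditional prediction unchanged;
        # larger values extrapolate away from the unconditional prediction
        return null_out + (cond_out - null_out) * cond_scale

    cond_out = torch.randn(2, 3, 64, 64)   # network output with conditioning kept
    null_out = torch.randn(2, 3, 64, 64)   # network output with conditioning replaced by the null embeddings
    noise_pred = guided_prediction(cond_out, null_out, cond_scale = 3.)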
@@ -967,7 +995,7 @@ class Decoder(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )

-    def p_losses(self, x_start, t, *, image_embed, noise = None):
+    def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))

         x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -976,6 +1004,7 @@ class Decoder(nn.Module):
             x_noisy,
             t,
             image_embed = image_embed,
+            text_encodings = text_encodings,
             cond_drop_prob = self.cond_drop_prob
         )

@@ -988,14 +1017,16 @@ class Decoder(nn.Module):

         return loss

-    def forward(self, image):
+    def forward(self, image, text = None):
         b, device, img_size, = image.shape[0], image.device, self.image_size
         check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)

         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
-        image_embed = self.get_image_embed(image)

-        loss = self.p_losses(image, times, image_embed = image_embed)
+        image_embed = self.get_image_embed(image)
+        text_encodings = self.get_text_encodings(text) if exists(text) else None
+
+        loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
         return loss

 # main class
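Taken together, the two commits let text conditioning flow from `Decoder.forward` at training time through `sample` at inference time. A hypothetical end to end sketch, reusing the `decoder`, `images`, and `text` placeholders from the note after the first hunk (all names and shapes are illustrative):

    # training step: text is optional and is only cross attended in the innermost Unet layers
    loss = decoder(images, text)
    loss.backward()

    # sampling: condition on CLIP image embeddings, optionally guided by the same text
    image_embed = decoder.get_image_embed(images)
    sampled_images = decoder.sample(image_embed, text = text, cond_scale = 2.)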