mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
allow for decoder conditioning with the text encodings from CLIP, if they are passed in. use a lazy linear layer to avoid researchers having to worry about text encoding dimensions, but remove it later if it does not work well
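To see the change from the user's side, here is a minimal training sketch assuming the CLIP / Unet / Decoder setup from the repository README that the first hunk below edits (the constructor hyperparameters are illustrative, not prescriptive); only the final decoder(images, text) call is new in this commit.

import torch
from dalle2_pytorch import CLIP, Unet, Decoder

# CLIP, Unet and Decoder configured roughly as in the README snippet this commit edits;
# exact hyperparameters are illustrative
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
).cuda()

decoder = Decoder(
    unet = unet,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

images = torch.randn(4, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (4, 256)).cuda()  # same token format used to train CLIP

# new in this commit: text is optional, decoder(images) still works exactly as before
loss = decoder(images, text)
loss.backward()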
@@ -276,7 +276,7 @@ decoder = Decoder(
     cond_drop_prob = 0.2
 ).cuda()

-loss = decoder(images)
+loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
 loss.backward()

 # do above for many steps
@@ -722,7 +722,7 @@ class Unet(nn.Module):
         dims = [channels, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))

-        # time and image embeddings
+        # time, image embeddings, and optional text encoding

         cond_dim = default(cond_dim, dim)

@@ -739,9 +739,12 @@ class Unet(nn.Module):
             Rearrange('b (n d) -> b n d', n = num_image_tokens)
         ) if image_embed_dim != cond_dim else nn.Identity()

+        self.text_to_cond = nn.LazyLinear(cond_dim)
+
         # for classifier free guidance

         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
+        self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))

         # layers

@@ -806,6 +809,7 @@ class Unet(nn.Module):
         time_tokens = self.time_mlp(time)

         cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
+        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')

         # mask out image embedding depending on condition dropout
         # for classifier free guidance
@@ -813,12 +817,31 @@ class Unet(nn.Module):
         image_tokens = self.image_to_cond(image_embed)

         image_tokens = torch.where(
-            rearrange(cond_prob_mask, 'b -> b 1 1'),
+            cond_prob_mask,
             image_tokens,
             self.null_image_embed
         )

-        c = torch.cat((time_tokens, image_tokens), dim = -2) # c for condition
+        # take care of text encodings (optional)
+
+        if exists(text_encodings):
+            text_tokens = self.text_to_cond(text_encodings)
+            text_tokens = torch.where(
+                cond_prob_mask,
+                text_tokens,
+                self.null_text_embed
+            )
+
+        # main conditioning tokens (c)
+
+        c = torch.cat((time_tokens, image_tokens), dim = -2)
+
+        # text and image conditioning tokens (mid_c)
+        # to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
+
+        mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
+
+        # go through the layers of the unet, down and up

         hiddens = []

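The masking in the hunk above is standard classifier-free-guidance dropout: a per-sample boolean mask is broadcast to shape (b, 1, 1) and used to swap the real conditioning tokens for learned null embeddings. Below is a self-contained sketch of that mechanism; keep_mask_like is an illustrative stand-in for the repository's own prob_mask_like helper, whose exact keep/drop convention may differ.

import torch
from einops import rearrange

def keep_mask_like(shape, keep_prob, device):
    # illustrative helper: True means "keep the real conditioning tokens for this sample"
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob

batch, num_tokens, cond_dim = 4, 4, 512

image_tokens = torch.randn(batch, num_tokens, cond_dim)
text_tokens  = torch.randn(batch, 77, cond_dim)

null_image_embed = torch.randn(1, num_tokens, cond_dim)  # learned parameter in the Unet
null_text_embed  = torch.randn(1, 1, cond_dim)           # learned parameter in the Unet

mask = keep_mask_like((batch,), 1 - 0.2, device = image_tokens.device)  # cond_drop_prob = 0.2
mask = rearrange(mask, 'b -> b 1 1')                                    # broadcast over tokens and feature dim

# where the mask is False, that sample's conditioning is replaced by the null embedding,
# which is what allows the model to be queried unconditionally at sampling time
image_tokens = torch.where(mask, image_tokens, null_image_embed)
text_tokens  = torch.where(mask, text_tokens, null_text_embed)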
@@ -828,9 +851,9 @@ class Unet(nn.Module):
             hiddens.append(x)
             x = downsample(x)

-        x = self.mid_block1(x, c)
+        x = self.mid_block1(x, mid_c)
         x = self.mid_attn(x)
-        x = self.mid_block2(x, c)
+        x = self.mid_block2(x, mid_c)

         for convnext, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
@@ -896,6 +919,10 @@ class Decoder(nn.Module):
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

+    def get_text_encodings(self, text):
+        text_encodings = self.clip.text_transformer(text)
+        return text_encodings[:, 1:]
+
     def get_image_embed(self, image):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
@@ -923,8 +950,8 @@ class Decoder(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
-        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
+    def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))

         if clip_denoised:
             x_recon.clamp_(-1., 1.)
@@ -933,31 +960,32 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance

     @torch.no_grad()
-    def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, shape, image_embed, cond_scale = 1):
+    def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
         device = self.betas.device

         b = shape[0]
         img = torch.randn(shape, device=device)

         for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
         return img

     @torch.no_grad()
-    def sample(self, image_embed, cond_scale = 1.):
+    def sample(self, image_embed, text = None, cond_scale = 1.):
         batch_size = image_embed.shape[0]
         image_size = self.image_size
         channels = self.channels
-        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
+        text_encodings = self.get_text_encodings(text) if exists(text) else None
+        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)

     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
@@ -967,7 +995,7 @@ class Decoder(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )

-    def p_losses(self, x_start, t, *, image_embed, noise = None):
+    def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))

         x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -976,6 +1004,7 @@ class Decoder(nn.Module):
             x_noisy,
             t,
             image_embed = image_embed,
+            text_encodings = text_encodings,
             cond_drop_prob = self.cond_drop_prob
         )

@@ -988,14 +1017,16 @@ class Decoder(nn.Module):

         return loss

-    def forward(self, image):
+    def forward(self, image, text = None):
         b, device, img_size, = image.shape[0], image.device, self.image_size
         check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)

         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
-        image_embed = self.get_image_embed(image)

-        loss = self.p_losses(image, times, image_embed = image_embed)
+        image_embed = self.get_image_embed(image)
+        text_encodings = self.get_text_encodings(text) if exists(text) else None
+
+        loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
         return loss

 # main class
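On the sampling side, Decoder.sample now accepts raw text tokens and converts them internally via get_text_encodings (both shown in the diff above). A short usage sketch, continuing from the training example near the top; in a full pipeline the image embedding would come from the diffusion prior rather than from a ground-truth image.

# continuing from the training sketch above (decoder, images, text already defined)
with torch.no_grad():
    image_embed = decoder.get_image_embed(images)       # Decoder method shown in this diff
    samples = decoder.sample(image_embed, text = text)  # text conditioning is optional here as well

# samples has shape (batch, channels, image_size, image_size), e.g. (4, 3, 256, 256) with the settings above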