Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
allow for training the Prior network with precomputed CLIP embeddings (or text encodings)
@@ -692,13 +692,41 @@ class DiffusionPrior(nn.Module):
         top_image_embeds = image_embeds.gather(1, top_sim_indices)
         return rearrange(top_image_embeds, 'b 1 d -> b d')
 
-    def forward(self, text, image, *args, **kwargs):
-        b, device, img_size, = image.shape[0], image.device, self.image_size
-        check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
+    def forward(
+        self,
+        text = None,
+        image = None,
+        text_embed = None,      # allow for training on preprocessed CLIP text and image embeddings
+        image_embed = None,
+        text_encodings = None,  # as well as CLIP text encodings
+        text_mask = None,       # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
+        *args,
+        **kwargs
+    ):
+        assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
+        assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
+        assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
 
-        times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
-        image_embed = self.get_image_embed(image)
-        text_cond = self.get_text_cond(text)
+        if exists(image):
+            image_embed = self.get_image_embed(image)
+
+        # calculate text conditionings, based on what is passed in
+
+        if exists(text):
+            text_cond = self.get_text_cond(text)
+        else:
+            text_cond = dict(
+                text_embed = text_embed,
+                text_encodings = text_encodings,
+                mask = text_mask
+            )
+
+        # timestep conditioning from ddpm
+
+        batch, device = image_embed.shape[0], image_embed.device
+        times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
+
+        # calculate forward loss
 
         loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
         return loss
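Below is a minimal sketch of how the new keyword arguments might be used in a training step. It assumes diffusion_prior is an already constructed DiffusionPrior whose prior network operates on 512-dimensional CLIP embeddings; the random tensors are hypothetical stand-ins for embeddings and encodings precomputed offline with CLIP, and the batch size, text length, and dimension are illustrative only. Because text and image are omitted, the forward pass builds its text conditioning from text_embed / text_encodings / text_mask and trains directly on image_embed.

import torch

# hypothetical stand-ins for values precomputed offline with CLIP
# (batch of 4, text length of 256, embedding dim of 512 are assumptions)
text_embed     = torch.randn(4, 512)        # pooled CLIP text embeddings
image_embed    = torch.randn(4, 512)        # CLIP image embeddings
text_encodings = torch.randn(4, 256, 512)   # per-token CLIP text encodings
text_mask      = torch.ones(4, 256).bool()  # mask over valid text tokens

# train the prior on the precomputed embeddings,
# so CLIP itself never has to run inside the training loop
loss = diffusion_prior(
    text_embed = text_embed,
    image_embed = image_embed,
    text_encodings = text_encodings,
    text_mask = text_mask
)
loss.backward()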