mirror of https://github.com/lucidrains/DALLE2-pytorch.git
generate text mask within the unet and diffusion prior itself from the text encodings, if not given
@@ -872,7 +872,7 @@ class DiffusionPriorNetwork(nn.Module):
             text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)

         if not exists(mask):
-            mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
+            mask = torch.any(text_encodings != 0., dim = -1)

         # classifier free guidance

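For context on this first hunk: the old fallback of `torch.ones` marked every position as attendable, padding included, while the new `torch.any(text_encodings != 0., dim = -1)` keeps a position only when its encoding vector is non-zero somewhere, which matches encoders that zero out positions past the true text length. A minimal sketch of the difference (shapes are illustrative, not taken from the repo):

    import torch

    batch, seq_len, dim = 2, 4, 8
    text_encodings = torch.randn(batch, seq_len, dim)
    text_encodings[:, 2:] = 0.  # simulate zero-padded positions past the real text

    # old fallback: every position attended to, padding included
    mask_old = torch.ones((batch, text_encodings.shape[-2]), dtype = torch.bool)

    # new fallback: a position is kept only if any feature of its encoding is non-zero
    mask_new = torch.any(text_encodings != 0., dim = -1)

    print(mask_old[0].tolist())  # [True, True, True, True]
    print(mask_new[0].tolist())  # [True, True, False, False]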
@@ -1205,7 +1205,6 @@ class DiffusionPrior(nn.Module):

         if self.condition_on_text_encodings:
             assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
-            text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
             text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}

         # timestep conditioning from ddpm
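Here the wrapper class stops defaulting the mask itself; `DiffusionPriorNetwork` (previous hunk) now derives it whenever `mask` is not given, so a `None` mask can flow through `text_cond` unchanged. The deleted line used the repo's `exists`/`default` helpers, whose semantics are roughly the usual one-liners in this codebase (a sketch, not copied from the source):

    def exists(val):
        # true when a value was actually provided
        return val is not None

    def default(val, d):
        # return val if provided, else the fallback (calling it if it is a lambda)
        if exists(val):
            return val
        return d() if callable(d) else d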
@@ -1819,6 +1818,9 @@ class Unet(nn.Module):
         if exists(text_encodings) and self.cond_on_text_encodings:
             assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'

+            if not exists(text_mask):
+                text_mask = torch.any(text_encodings != 0., dim = -1)
+
             text_tokens = self.text_to_cond(text_encodings)
             text_tokens = text_tokens[:, :self.max_text_len]

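This hunk gives `Unet.forward` the same fallback, so from this point on `text_mask` is guaranteed to be a boolean `(batch, seq)` tensor whenever text conditioning is active — the invariant that lets the next hunk drop its `exists(text_mask)` guard. A quick sketch of the guarantee (illustrative shapes):

    import torch

    text_encodings = torch.randn(2, 5, 8)
    text_encodings[:, 3:] = 0.  # zero padding past the real text

    text_mask = None
    if text_mask is None:  # mirrors `if not exists(text_mask):`
        text_mask = torch.any(text_encodings != 0., dim = -1)

    assert text_mask.dtype == torch.bool and text_mask.shape == (2, 5)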
@@ -1827,13 +1829,10 @@ class Unet(nn.Module):

             if remainder > 0:
                 text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
+                text_mask = F.pad(text_mask, (0, remainder), value = False)

-            if exists(text_mask):
-                if remainder > 0:
-                    text_mask = F.pad(text_mask, (0, remainder), value = False)
-
-                text_mask = rearrange(text_mask, 'b n -> b n 1')
-                text_keep_mask = text_mask & text_keep_mask
+            text_mask = rearrange(text_mask, 'b n -> b n 1')
+            text_keep_mask = text_mask & text_keep_mask

             null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working

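With the mask guaranteed to exist, this hunk folds the mask padding into the existing `if remainder > 0:` block and unindents the rearrange/keep-mask lines. A sketch of the two `F.pad` calls — the pad tuple runs from the last dimension backwards, so `(0, 0, 0, remainder)` pads the sequence dimension of the 3-d tokens while `(0, remainder)` pads the 2-d mask, and `value = False` keeps the padded slots masked out (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    max_text_len, batch, cur_len, cond_dim = 6, 2, 4, 16
    remainder = max_text_len - cur_len

    text_tokens = torch.randn(batch, cur_len, cond_dim)
    text_mask = torch.ones(batch, cur_len, dtype = torch.bool)

    if remainder > 0:
        text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))      # -> (2, 6, 16)
        text_mask = F.pad(text_mask, (0, remainder), value = False) # -> (2, 6)

    assert text_tokens.shape == (batch, max_text_len, cond_dim)
    assert text_mask.shape == (batch, max_text_len)

The subsequent `rearrange(text_mask, 'b n -> b n 1')` reshapes the padding mask so it can combine by broadcasting with the per-sample `text_keep_mask` used for classifier-free guidance dropout.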
@@ -2440,9 +2439,6 @@ class Decoder(nn.Module):
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'

-        if self.condition_on_text_encodings:
-            text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
-
         img = None
         is_cuda = next(self.parameters()).is_cuda

@@ -2526,9 +2522,6 @@ class Decoder(nn.Module):
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'

-        if self.condition_on_text_encodings:
-            text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
-
         lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
         image = resize_image_to(image, target_image_size)

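These last two hunks remove the same fallback from `Decoder.sample` and `Decoder.forward`, leaving the derivation in exactly two places (`DiffusionPriorNetwork` and `Unet`). One nuance of the new default: a position counts as padding exactly when its entire encoding vector is zero, so the fallback assumes the text encoder zeroes out unused positions; an adapter that emits non-zero values at padded slots should keep passing an explicit mask. A quick illustration:

    import torch

    enc = torch.zeros(1, 3, 4)
    enc[0, 0] = 1.  # only the first position carries signal

    mask = torch.any(enc != 0., dim = -1)
    assert mask.tolist() == [[True, False, False]]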