Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 09:44:19 +01:00)
protect against bad text mask being passed into decoder
@@ -871,7 +871,7 @@ class DiffusionPriorNetwork(nn.Module):
         if not exists(text_encodings):
             text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
 
-        if not exists(mask):
+        if not exists(mask) or mask.numel() == 0:
             mask = torch.any(text_encodings != 0., dim = -1)
 
         # classifier free guidance
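Why the extra check matters: `exists` only tests for None, so a zero-element mask tensor passed by a caller sailed straight through and caused shape problems deeper in the network. Below is a minimal standalone sketch of the guard, assuming the repo's `exists` helper and illustrative batch/dim values (not taken from the library's API):

import torch

def exists(val):
    return val is not None

batch, dim = 2, 512
device, dtype = 'cpu', torch.float32

text_encodings = None
mask = torch.empty((batch, 0), dtype = torch.bool)  # a "bad" zero-element mask from a caller

if not exists(text_encodings):
    text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)

# before this commit, exists(mask) was True here and the empty mask slipped through;
# the added numel() check rebuilds it from the encodings (all-zero rows = padding)
if not exists(mask) or mask.numel() == 0:
    mask = torch.any(text_encodings != 0., dim = -1)

print(mask.shape)  # torch.Size([2, 0]), consistent with text_encodings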
@@ -1818,7 +1818,7 @@ class Unet(nn.Module):
         if exists(text_encodings) and self.cond_on_text_encodings:
             assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
 
-            if not exists(text_mask):
+            if not exists(text_mask) or text_mask.numel() == 0:
                 text_mask = torch.any(text_encodings != 0., dim = -1)
 
             text_tokens = self.text_to_cond(text_encodings)
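The Unet-side guard is the same pattern. A hedged repro sketch (the tensor shapes are illustrative assumptions): an empty `text_mask` is not None, so only the `numel()` check catches it and derives a per-token mask from the encodings instead.

import torch

text_encodings = torch.randn(4, 77, 512)           # (batch, tokens, text_embed_dim), illustrative
text_mask = torch.empty((0,), dtype = torch.bool)  # bad mask handed in by a caller

if text_mask is None or text_mask.numel() == 0:
    # same recovery as the patch: treat all-zero token rows as padding
    text_mask = torch.any(text_encodings != 0., dim = -1)

assert text_mask.shape == (4, 77)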
@@ -1835,7 +1835,7 @@ class Unet(nn.Module):
             text_mask = rearrange(text_mask, 'b n -> b n 1')
 
-            assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'
+            assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}. text encoding is of shape {text_encodings.shape}'
             text_keep_mask = text_mask & text_keep_mask
 
             null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
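For context on the extended assert message: the two masks must agree on the batch dimension before they can be AND-ed together. A minimal sketch, assuming `text_keep_mask` is the per-sample classifier-free-guidance dropout mask broadcast to shape (b, 1, 1) (the surrounding code that builds it is not shown in this diff):

import torch
from einops import rearrange

batch, tokens = 4, 77
text_mask = torch.ones(batch, tokens, dtype = torch.bool)
text_keep_mask = rearrange(torch.rand(batch) > 0.1, 'b -> b 1 1')  # keep ~90% of samples

text_mask = rearrange(text_mask, 'b n -> b n 1')
assert text_mask.shape[0] == text_keep_mask.shape[0]  # a mismatch here now also reports text_encodings' shape
text_keep_mask = text_mask & text_keep_mask           # broadcasts to (b, n, 1): per-token keep decisions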
@@ -1 +1 @@
-__version__ = '0.21.2'
+__version__ = '0.21.3'