Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
one more fix for the text mask when the length of the text encodings exceeds max_text_len; also add an assert for a better error message
@@ -1822,7 +1822,9 @@ class Unet(nn.Module):
             text_mask = torch.any(text_encodings != 0., dim = -1)
 
             text_tokens = self.text_to_cond(text_encodings)
 
             text_tokens = text_tokens[:, :self.max_text_len]
+            text_mask = text_mask[:, :self.max_text_len]
+
             text_tokens_len = text_tokens.shape[1]
             remainder = self.max_text_len - text_tokens_len
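For context, a minimal, self-contained sketch of the truncate-then-pad pattern this hunk completes. The tensor sizes and the standalone variables (max_text_len, text_encodings) are illustrative assumptions, not values taken from the repository; in the module itself the tokens come from self.text_to_cond. The point of the added line is that the mask must be cut to max_text_len exactly like the token embeddings, otherwise the two drift apart in length whenever the incoming text encodings are longer than max_text_len.

    # Hypothetical, standalone illustration of the truncate-then-pad pattern above.
    import torch
    import torch.nn.functional as F

    max_text_len = 256                                       # assumed limit
    text_encodings = torch.randn(2, 300, 512)                # (batch, seq, dim) with seq > max_text_len
    text_mask = torch.any(text_encodings != 0., dim = -1)    # (batch, seq) boolean mask

    text_tokens = text_encodings[:, :max_text_len]           # truncate the embeddings
    text_mask = text_mask[:, :max_text_len]                  # truncate the mask the same way (the added line)

    remainder = max_text_len - text_tokens.shape[1]
    if remainder > 0:                                         # pad back up when the text is shorter than the limit
        text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
        text_mask = F.pad(text_mask, (0, remainder), value = False)

    assert text_tokens.shape[1] == text_mask.shape[1] == max_text_len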
@@ -1832,6 +1834,8 @@ class Unet(nn.Module):
             text_mask = F.pad(text_mask, (0, remainder), value = False)
 
             text_mask = rearrange(text_mask, 'b n -> b n 1')
+
+            assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'
             text_keep_mask = text_mask & text_keep_mask
 
             null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
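The new assert guards the element-wise AND that follows it: text_mask (per token) and text_keep_mask (per sample, from conditioning dropout) can only be combined if their batch dimensions agree, which is exactly what breaks when the mask and the tokens are truncated inconsistently. Below is a rough, hypothetical sketch of that combination step with made-up shapes, a made-up drop probability, and a zero tensor standing in for the learned null embedding; it is not the module's actual code.

    # Hypothetical sketch: combining a per-sample keep mask with a per-token text mask,
    # then falling back to a null embedding wherever the text conditioning is dropped.
    import torch
    from einops import rearrange

    batch, max_text_len, cond_dim = 2, 256, 512
    cond_drop_prob = 0.2                                      # assumed dropout probability

    text_tokens = torch.randn(batch, max_text_len, cond_dim)
    text_mask = torch.ones(batch, max_text_len, dtype = torch.bool)
    null_text_embed = torch.zeros(1, max_text_len, cond_dim)  # stand-in for a learned parameter

    # per-sample keep mask: False means "drop the text conditioning for this sample"
    text_keep_mask = torch.rand(batch) > cond_drop_prob
    text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')

    text_mask = rearrange(text_mask, 'b n -> b n 1')
    assert text_mask.shape[0] == text_keep_mask.shape[0]      # batch dims must agree, as asserted in the diff
    text_keep_mask = text_mask & text_keep_mask               # broadcasts to (b, n, 1)

    # wherever a token is dropped or padded out, use the null embedding instead
    text_tokens = torch.where(text_keep_mask, text_tokens, null_text_embed)

The sketch only shows how the shapes have to line up so that a mismatch is caught early by the assert; in the actual module the null embedding is a learned parameter and the keep mask is derived from the configured conditioning-dropout probability.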
@@ -1 +1 @@
-__version__ = '0.21.1'
+__version__ = '0.21.2'