diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 7a516de..d02cda1 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1822,7 +1822,9 @@ class Unet(nn.Module): text_mask = torch.any(text_encodings != 0., dim = -1) text_tokens = self.text_to_cond(text_encodings) + text_tokens = text_tokens[:, :self.max_text_len] + text_mask = text_mask[:, :self.max_text_len] text_tokens_len = text_tokens.shape[1] remainder = self.max_text_len - text_tokens_len @@ -1832,6 +1834,8 @@ class Unet(nn.Module): text_mask = F.pad(text_mask, (0, remainder), value = False) text_mask = rearrange(text_mask, 'b n -> b n 1') + + assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}' text_keep_mask = text_mask & text_keep_mask null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 8c306aa..e9734d4 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.21.1' +__version__ = '0.21.2'