diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index ee543b1..98fe763 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1070,7 +1070,7 @@ class DiffusionPriorNetwork(nn.Module): null_text_embeds = self.null_text_embeds.to(text_embed.dtype) - text_embeds = torch.where( + text_embed = torch.where( text_keep_mask, text_embed, null_text_embeds diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 78a6e51..b436016 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.10.3' +__version__ = '1.10.4'