diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 4640878..f227066 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1492,11 +1492,12 @@ class Unet(nn.Module): if self.cond_on_image_embeds: image_tokens = self.image_to_cond(image_embed) + null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working image_tokens = torch.where( image_keep_mask, image_tokens, - self.null_image_embed + null_image_embed ) # take care of text encodings (optional) @@ -1520,10 +1521,12 @@ class Unet(nn.Module): text_mask = rearrange(text_mask, 'b n -> b n 1') text_keep_mask = text_mask & text_keep_mask + null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working + text_tokens = torch.where( text_keep_mask, text_tokens, - self.null_text_embed + null_text_embed ) # main conditioning tokens (c) diff --git a/setup.py b/setup.py index 4789852..6001e37 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.2.9', + version = '0.2.10', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',