Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
fix amp issue for https://github.com/lucidrains/DALLE2-pytorch/issues/82
@@ -1492,11 +1492,12 @@ class Unet(nn.Module):
 
         if self.cond_on_image_embeds:
             image_tokens = self.image_to_cond(image_embed)
+            null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
 
             image_tokens = torch.where(
                 image_keep_mask,
                 image_tokens,
-                self.null_image_embed
+                null_image_embed
             )
 
         # take care of text encodings (optional)
@@ -1520,10 +1521,12 @@ class Unet(nn.Module):
             text_mask = rearrange(text_mask, 'b n -> b n 1')
             text_keep_mask = text_mask & text_keep_mask
 
+            null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
+
             text_tokens = torch.where(
                 text_keep_mask,
                 text_tokens,
-                self.null_text_embed
+                null_text_embed
             )
 
         # main conditioning tokens (c)
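Both hunks apply the same workaround: under torch.cuda.amp.autocast the conditioning tokens come out of a Linear projection in float16, while the bare null-embedding parameters stay in float32, and on the affected PyTorch versions torch.where refuses to mix the two dtypes. Casting the null embedding to the tokens' dtype right before the torch.where call sidesteps that. The sketch below is a minimal, hypothetical reproduction, not part of the commit; the names to_cond, null_embed and keep_mask and the tensor shapes are illustrative, and it assumes a CUDA device plus a PyTorch build where torch.where does not promote dtypes.

import torch
from torch import nn

# illustrative stand-ins for Unet.image_to_cond and Unet.null_image_embed
to_cond = nn.Linear(512, 512).cuda()
null_embed = nn.Parameter(torch.randn(1, 1, 512, device = 'cuda'))  # bare parameter, stays float32 under autocast
image_embed = torch.randn(2, 1, 512, device = 'cuda')
keep_mask = torch.rand(2, 1, 1, device = 'cuda') > 0.1              # classifier-free-guidance style keep mask

with torch.cuda.amp.autocast():
    image_tokens = to_cond(image_embed)                              # autocast runs the Linear in float16

    # mixing the float16 tokens with the float32 parameter here raises a dtype
    # mismatch on PyTorch versions without type promotion in torch.where:
    #   image_tokens = torch.where(keep_mask, image_tokens, null_embed)

    # the commit's workaround: cast the null embedding to the tokens' dtype first
    null_embed_cast = null_embed.to(image_tokens.dtype)
    image_tokens = torch.where(keep_mask, image_tokens, null_embed_cast)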