be able to turn off warning for use of LazyLinear by passing in text embedding dimension for unet

This commit is contained in:
Phil Wang
2022-04-26 11:42:46 -07:00
parent eafb136214
commit de0296106b
2 changed files with 3 additions and 2 deletions

View File

@@ -906,6 +906,7 @@ class Unet(nn.Module):
dim,
*,
image_embed_dim,
text_embed_dim = None,
cond_dim = None,
num_image_tokens = 4,
num_time_tokens = 2,
@@ -959,7 +960,7 @@ class Unet(nn.Module):
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
self.text_to_cond = nn.LazyLinear(cond_dim)
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.51',
version = '0.0.52',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',