mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
be able to turn off warning for use of LazyLinear by passing in text embedding dimension for unet
This commit is contained in:
@@ -906,6 +906,7 @@ class Unet(nn.Module):
|
||||
dim,
|
||||
*,
|
||||
image_embed_dim,
|
||||
text_embed_dim = None,
|
||||
cond_dim = None,
|
||||
num_image_tokens = 4,
|
||||
num_time_tokens = 2,
|
||||
@@ -959,7 +960,7 @@ class Unet(nn.Module):
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.text_to_cond = nn.LazyLinear(cond_dim)
|
||||
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
|
||||
|
||||
# finer control over whether to condition on image embeddings and text encodings
|
||||
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
||||
|
||||
Reference in New Issue
Block a user