LazyLinear is not mature, make users pass in text_embed_dim if text conditioning is turned on

This commit is contained in:
Phil Wang
2022-05-15 18:56:52 -07:00
parent 4a4c7ac9e6
commit 36c5079bd7
3 changed files with 3 additions and 7 deletions

View File

@@ -1387,7 +1387,8 @@ class Unet(nn.Module):
self.text_to_cond = None
if cond_on_text_encodings:
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting