mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
more informative error for something that tripped me up
This commit is contained in:
@@ -1537,10 +1537,12 @@ class Unet(nn.Module):
|
|||||||
# text encoding conditioning (optional)
|
# text encoding conditioning (optional)
|
||||||
|
|
||||||
self.text_to_cond = None
|
self.text_to_cond = None
|
||||||
|
self.text_embed_dim = None
|
||||||
|
|
||||||
if cond_on_text_encodings:
|
if cond_on_text_encodings:
|
||||||
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
|
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
|
||||||
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
|
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
|
||||||
|
self.text_embed_dim = text_embed_dim
|
||||||
|
|
||||||
# finer control over whether to condition on image embeddings and text encodings
|
# finer control over whether to condition on image embeddings and text encodings
|
||||||
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
||||||
@@ -1769,6 +1771,8 @@ class Unet(nn.Module):
|
|||||||
text_tokens = None
|
text_tokens = None
|
||||||
|
|
||||||
if exists(text_encodings) and self.cond_on_text_encodings:
|
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||||
|
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
|
||||||
|
|
||||||
text_tokens = self.text_to_cond(text_encodings)
|
text_tokens = self.text_to_cond(text_encodings)
|
||||||
text_tokens = text_tokens[:, :self.max_text_len]
|
text_tokens = text_tokens[:, :self.max_text_len]
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.19.1'
|
__version__ = '0.19.2'
|
||||||
|
|||||||
Reference in New Issue
Block a user