more informative error for something that tripped me up

This commit is contained in:
Phil Wang
2022-07-09 17:28:14 -07:00
parent b7e22f7da0
commit 47ae17b36e
2 changed files with 5 additions and 1 deletion

View File

@@ -1537,10 +1537,12 @@ class Unet(nn.Module):
# text encoding conditioning (optional)
self.text_to_cond = None
self.text_embed_dim = None
if cond_on_text_encodings:
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
self.text_embed_dim = text_embed_dim
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1769,6 +1771,8 @@ class Unet(nn.Module):
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
text_tokens = self.text_to_cond(text_encodings)
text_tokens = text_tokens[:, :self.max_text_len]

View File

@@ -1 +1 @@
- __version__ = '0.19.1'
+ __version__ = '0.19.2'