From 47ae17b36e5f1dc3fc2550a3b229779b080070a1 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 9 Jul 2022 17:28:14 -0700
Subject: [PATCH] more informative error for something that tripped me up

---
 dalle2_pytorch/dalle2_pytorch.py | 4 ++++
 dalle2_pytorch/version.py        | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 6e32b0c..86c030d 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1537,10 +1537,12 @@ class Unet(nn.Module):
         # text encoding conditioning (optional)
 
         self.text_to_cond = None
+        self.text_embed_dim = None
 
         if cond_on_text_encodings:
             assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
             self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
+            self.text_embed_dim = text_embed_dim
 
         # finer control over whether to condition on image embeddings and text encodings
         # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1769,6 +1771,8 @@ class Unet(nn.Module):
         text_tokens = None
 
         if exists(text_encodings) and self.cond_on_text_encodings:
+            assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
+
             text_tokens = self.text_to_cond(text_encodings)
             text_tokens = text_tokens[:, :self.max_text_len]
 
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index db7a416..5daae67 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.19.1'
+__version__ = '0.19.2'
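
Note (not part of the patch): below is a minimal, hypothetical sketch of the failure mode this patch makes explicit. If the text encodings passed at forward time have a last dimension different from the text_embed_dim the unet was created with, the nn.Linear projection previously failed with an opaque shape error; the added assert now reports both dimensions. The snippet uses plain torch stand-ins (text_to_cond, dummy tensor shapes) rather than the real Unet, so all names and values here are illustrative assumptions, not the library's API.

    import torch
    import torch.nn as nn

    # stand-in for the unet's text conditioning projection (illustrative values)
    text_embed_dim = 512                      # dimension the "unet" was created with
    text_to_cond = nn.Linear(text_embed_dim, 128)

    # text encodings from a mismatched text encoder, e.g. last dim 768 instead of 512
    text_encodings = torch.randn(2, 77, 768)

    try:
        # mirrors the check added in Unet.forward by this patch
        assert text_embed_dim == text_encodings.shape[-1], (
            f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, '
            f'but the unet was created with text_embed_dim of {text_embed_dim}.'
        )
        text_tokens = text_to_cond(text_encodings)
    except AssertionError as e:
        print(e)  # clear message instead of an opaque matmul shape error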