From 11b1d533a0e60e703ef2a8c71c79f3e9f1135e50 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 12 Jul 2022 16:00:19 -0700 Subject: [PATCH] make sure text encodings being passed in has the correct batch dimension --- dalle2_pytorch/dalle2_pytorch.py | 1 + dalle2_pytorch/version.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 8571cb2..7dc015f 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1812,6 +1812,7 @@ class Unet(nn.Module): text_tokens = None if exists(text_encodings) and self.cond_on_text_encodings: + assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}' assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.' text_mask = torch.any(text_encodings != 0., dim = -1) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 81edede..66d9d1e 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.22.0' +__version__ = '0.22.1'