Compare commits

..

1 Commits

Author SHA1 Message Date
Phil Wang
11b1d533a0 make sure text encodings being passed in has the correct batch dimension 2022-07-12 16:00:19 -07:00
2 changed files with 2 additions and 1 deletions

View File

@@ -1812,6 +1812,7 @@ class Unet(nn.Module):
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
text_mask = torch.any(text_encodings != 0., dim = -1)

View File

@@ -1 +1 @@
__version__ = '0.22.0'
__version__ = '0.22.1'