make sure text encodings being passed in have the correct batch dimension

This commit is contained in:
Phil Wang
2022-07-12 16:00:19 -07:00
parent e76e89f9eb
commit 11b1d533a0
2 changed files with 2 additions and 1 deletion


@@ -1812,6 +1812,7 @@ class Unet(nn.Module):
         text_tokens = None
         if exists(text_encodings) and self.cond_on_text_encodings:
+            assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet do not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'
             assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
             text_mask = torch.any(text_encodings != 0., dim = -1)
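
For context, a minimal standalone sketch of the shape validation this commit adds. The helper name and the example shapes are hypothetical illustrations, not part of the repository: the batch dimension of the text encodings must match the image batch size, and the feature dimension must match the text_embed_dim the unet was configured with; token positions whose encoding is all zeros are treated as padding when building the mask.

    import torch

    def validate_text_encodings(text_encodings, batch_size, text_embed_dim):
        # the batch dimension must match the image batch being denoised
        assert text_encodings.shape[0] == batch_size, \
            f'text encodings have batch size {text_encodings.shape[0]}, expected {batch_size}'
        # the feature dimension must match what the unet was created with
        assert text_encodings.shape[-1] == text_embed_dim, \
            f'text encodings have dimension {text_encodings.shape[-1]}, expected {text_embed_dim}'
        # positions whose encoding is entirely zero are treated as padding
        return torch.any(text_encodings != 0., dim = -1)

    # example: batch of 4 captions, 77 tokens each, 512-dim encodings
    encodings = torch.randn(4, 77, 512)
    mask = validate_text_encodings(encodings, batch_size = 4, text_embed_dim = 512)
    print(mask.shape)  # torch.Size([4, 77])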