mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
make sure text encodings being passed in has the correct batch dimension
This commit is contained in:
@@ -1812,6 +1812,7 @@ class Unet(nn.Module):
|
|||||||
text_tokens = None
|
text_tokens = None
|
||||||
|
|
||||||
if exists(text_encodings) and self.cond_on_text_encodings:
|
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||||
|
assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'
|
||||||
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
|
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
|
||||||
|
|
||||||
text_mask = torch.any(text_encodings != 0., dim = -1)
|
text_mask = torch.any(text_encodings != 0., dim = -1)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.22.0'
|
__version__ = '0.22.1'
|
||||||
|
|||||||
Reference in New Issue
Block a user