diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index e87788d..cf3217c 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1166,6 +1166,10 @@ class DiffusionPrior(nn.Module): self.net = net self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent) + + assert net.dim == self.image_embed_dim, f'your diffusion prior network has a dimension of {net.dim}, but you set your image embedding dimension (keyword image_embed_dim) on DiffusionPrior to {self.image_embed_dim}' + assert not exists(clip) or clip.dim_latent == self.image_embed_dim, f'you passed in a CLIP to the diffusion prior with latent dimensions of {clip.dim_latent}, but your image embedding dimension (keyword image_embed_dim) for the DiffusionPrior was set to {self.image_embed_dim}' + self.channels = default(image_channels, lambda: clip.image_channels) self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 52af183..ff987d2 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.10.0' +__version__ = '1.10.1'