From ba58ae0bf2416a11e522d7f3afd604a9c5d488c6 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 28 Aug 2022 10:11:37 -0700 Subject: [PATCH] add two asserts to diffusion prior to ensure matching image embedding dimensions for clip, diffusion prior network, and what was set on diffusion prior --- dalle2_pytorch/dalle2_pytorch.py | 4 ++++ dalle2_pytorch/version.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index e87788d..cf3217c 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1166,6 +1166,10 @@ class DiffusionPrior(nn.Module): self.net = net self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent) + + assert net.dim == self.image_embed_dim, f'your diffusion prior network has a dimension of {net.dim}, but you set your image embedding dimension (keyword image_embed_dim) on DiffusionPrior to {self.image_embed_dim}' + assert not exists(clip) or clip.dim_latent == self.image_embed_dim, f'you passed in a CLIP to the diffusion prior with latent dimensions of {clip.dim_latent}, but your image embedding dimension (keyword image_embed_dim) for the DiffusionPrior was set to {self.image_embed_dim}' + self.channels = default(image_channels, lambda: clip.image_channels) self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 52af183..ff987d2 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.10.0' +__version__ = '1.10.1'