diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 23c3a7a..4907b75 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -890,6 +890,8 @@ class DiffusionPrior(BaseGaussianDiffusion): ) if exists(clip): + assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})' + if isinstance(clip, CLIP): clip = XClipAdapter(clip, **clip_adapter_overrides) elif isinstance(clip, CoCa): @@ -1721,6 +1723,7 @@ class Decoder(BaseGaussianDiffusion): self.clip = None if exists(clip): assert not unconditional, 'clip must not be given if doing unconditional image training' + assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})' if isinstance(clip, CLIP): clip = XClipAdapter(clip, **clip_adapter_overrides) diff --git a/setup.py b/setup.py index b83e048..c564511 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.4.11', + version = '0.4.14', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',