From dbf4a281f190514f26a3e723bd726c0ff0c0985b Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 27 Apr 2022 20:45:27 -0700 Subject: [PATCH] make sure another CLIP can actually be passed in, as long as it is wrapped in an adapter extended from BaseClipAdapter --- dalle2_pytorch/dalle2_pytorch.py | 9 +++++++-- setup.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index bbb9a53..be861c9 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -647,9 +647,12 @@ class DiffusionPrior(BaseGaussianDiffusion): ) if exists(clip): - assert isinstance(clip, CLIP) + if isinstance(clip, CLIP): + clip = XClipAdapter(clip) + + assert isinstance(clip, BaseClipAdapter) freeze_model_and_make_eval_(clip) - self.clip = XClipAdapter(clip) + self.clip = clip else: assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given' self.clip = None @@ -1248,6 +1251,8 @@ class Decoder(BaseGaussianDiffusion): clip = XClipAdapter(clip) freeze_model_and_make_eval_(clip) + assert isinstance(clip, BaseClipAdapter) + self.clip = clip self.clip_image_size = clip.image_size self.channels = clip.image_channels diff --git a/setup.py b/setup.py index 636e449..894f64d 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.60', + version = '0.0.61', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',