just use an assert to make sure clip image channels is never different than the channels of the diffusion prior and decoder, if clip is given

This commit is contained in:
Phil Wang
2022-05-22 22:43:14 -07:00
parent 276abf337b
commit fa533962bd
2 changed files with 4 additions and 1 deletions

View File

@@ -890,6 +890,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
)
if exists(clip):
assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
@@ -1721,6 +1723,7 @@ class Decoder(BaseGaussianDiffusion):
self.clip = None
if exists(clip):
assert not unconditional, 'clip must not be given if doing unconditional image training'
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)