mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-15 00:34:19 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0e41267f8 |
@@ -890,6 +890,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if exists(clip):
|
if exists(clip):
|
||||||
|
assert image_channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||||
|
|
||||||
if isinstance(clip, CLIP):
|
if isinstance(clip, CLIP):
|
||||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||||
elif isinstance(clip, CoCa):
|
elif isinstance(clip, CoCa):
|
||||||
@@ -1721,6 +1723,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.clip = None
|
self.clip = None
|
||||||
if exists(clip):
|
if exists(clip):
|
||||||
assert not unconditional, 'clip must not be given if doing unconditional image training'
|
assert not unconditional, 'clip must not be given if doing unconditional image training'
|
||||||
|
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||||
|
|
||||||
if isinstance(clip, CLIP):
|
if isinstance(clip, CLIP):
|
||||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||||
|
|||||||
Reference in New Issue
Block a user