mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 03:54:35 +01:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
72bf159331 | ||
|
|
e5e47cfecb | ||
|
|
fa533962bd | ||
|
|
276abf337b |
@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
|
||||
|
||||
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
|
||||
|
||||
## Status
|
||||
|
||||
@@ -26,7 +26,7 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
|
||||
|
||||
## Pre-Trained Models
|
||||
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
|
||||
- Decoder 🚧
|
||||
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
|
||||
- DALL-E 2 🚧
|
||||
|
||||
## Install
|
||||
|
||||
@@ -890,6 +890,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
)
|
||||
|
||||
if exists(clip):
|
||||
assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
@@ -1710,12 +1712,19 @@ class Decoder(BaseGaussianDiffusion):
|
||||
)
|
||||
|
||||
self.unconditional = unconditional
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
|
||||
assert self.unconditional or (exists(clip) or exists(image_size) or exists(image_sizes)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
# text conditioning
|
||||
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# clip
|
||||
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
assert not unconditional, 'clip must not be given if doing unconditional image training'
|
||||
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
@@ -1725,13 +1734,20 @@ class Decoder(BaseGaussianDiffusion):
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
|
||||
self.clip = clip
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
else:
|
||||
self.clip_image_size = default(image_size, lambda: image_sizes[-1])
|
||||
self.channels = channels
|
||||
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
# determine image size, with image_size and image_sizes taking precedence
|
||||
|
||||
if exists(image_size) or exists(image_sizes):
|
||||
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
|
||||
image_size = default(image_size, lambda: image_sizes[-1])
|
||||
elif exists(clip):
|
||||
image_size = clip.image_size
|
||||
else:
|
||||
raise Error('either image_size, image_sizes, or clip must be given to decoder')
|
||||
|
||||
# channels
|
||||
|
||||
self.channels = channels
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
@@ -1773,7 +1789,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
# unet image sizes
|
||||
|
||||
image_sizes = default(image_sizes, (self.clip_image_size,))
|
||||
image_sizes = default(image_sizes, (image_size,))
|
||||
image_sizes = tuple(sorted(set(image_sizes)))
|
||||
|
||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||
@@ -1811,6 +1827,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.clip_x_start = clip_x_start
|
||||
|
||||
# normalize and unnormalize image functions
|
||||
|
||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
||||
|
||||
|
||||
Reference in New Issue
Block a user