Compare commits


1 commit

Author:  Phil Wang
SHA1:    dc50c6b34e
Message: allow for config driven creation of clip-less diffusion prior
Date:    2022-05-22 20:13:20 -07:00
5 changed files with 17 additions and 40 deletions
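Config-driven creation of a clip-less prior means a prior config can omit the CLIP section entirely and declare the embedding dimension instead. A minimal sketch of the idea; `DiffusionPriorConfig`, its field names, and the `create()` helper are assumed for illustration rather than taken from the hunks below:

```python
# all names below are assumed for illustration, not shown in this diff
from dalle2_pytorch.train_configs import DiffusionPriorConfig  # assumed module path

prior_config = DiffusionPriorConfig(
    clip = None,                # clip-less: no CLIP model in the config
    image_embed_dim = 512,      # embedding dimension supplied directly
    net = dict(dim = 512, depth = 6, dim_head = 64, heads = 8),
    timesteps = 100,
    cond_drop_prob = 0.2
)

diffusion_prior = prior_config.create()  # assumed helper that builds the DiffusionPrior
```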

README.md

@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Status
@@ -26,7 +26,7 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- Decoder 🚧
- DALL-E 2 🚧
## Install

dalle2_pytorch/dalle2_pytorch.py

@@ -890,8 +890,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
)
if exists(clip):
assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
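With the CLIP-specific checks relaxed, a DiffusionPrior can be built without a `clip` at all and trained on precomputed embeddings. A minimal sketch along the lines of the repository README's precomputed-embeddings example (all hyperparameter values are illustrative):

```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

# no `clip` passed in; the embedding dimension is given directly
diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 512,
    timesteps = 100,
    cond_drop_prob = 0.2
)

# train on precomputed CLIP text and image embeddings
text_embed = torch.randn(4, 512)
image_embed = torch.randn(4, 512)

loss = diffusion_prior(text_embed = text_embed, image_embed = image_embed)
loss.backward()
```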
@@ -1712,19 +1710,12 @@ class Decoder(BaseGaussianDiffusion):
)
self.unconditional = unconditional
# text conditioning
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
self.condition_on_text_encodings = condition_on_text_encodings
# clip
assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
self.clip = None
if exists(clip):
assert not unconditional, 'clip must not be given if doing unconditional image training'
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
@@ -1734,20 +1725,13 @@ class Decoder(BaseGaussianDiffusion):
assert isinstance(clip, BaseClipAdapter)
self.clip = clip
# determine image size, with image_size and image_sizes taking precedence
if exists(image_size) or exists(image_sizes):
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
image_size = default(image_size, lambda: image_sizes[-1])
elif exists(clip):
image_size = clip.image_size
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
else:
raise ValueError('either image_size, image_sizes, or clip must be given to decoder')
self.clip_image_size = default(image_size, lambda: image_sizes[-1])
self.channels = channels
# channels
self.channels = channels
self.condition_on_text_encodings = condition_on_text_encodings
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
@@ -1789,7 +1773,7 @@ class Decoder(BaseGaussianDiffusion):
# unet image sizes
image_sizes = default(image_sizes, (image_size,))
image_sizes = default(image_sizes, (self.clip_image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
@@ -1827,7 +1811,6 @@ class Decoder(BaseGaussianDiffusion):
self.clip_x_start = clip_x_start
# normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
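The reworked assertions spell out the contract: either `clip` is supplied, or `image_size`/`image_sizes` (plus `channels`) must be. A minimal clip-less sketch under that contract; passing a precomputed `image_embed` to the forward call is assumed from the repository's embedding-based examples:

```python
import torch
from dalle2_pytorch import Unet, Decoder

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

# clip-less: image size and channels are given explicitly instead of a CLIP model
decoder = Decoder(
    unet = unet,
    image_size = 64,
    channels = 3,
    timesteps = 100
)

images = torch.randn(4, 3, 64, 64)
image_embed = torch.randn(4, 512)   # precomputed, e.g. from a frozen CLIP

loss = decoder(images, image_embed = image_embed)
loss.backward()
```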

dalle2_pytorch/train_configs.py

@@ -68,7 +68,7 @@ class UnetConfig(BaseModel):
extra = "allow"
class DecoderConfig(BaseModel):
unets: ListOrTuple(UnetConfig)
unets: Union[List[UnetConfig], Tuple[UnetConfig]]
image_size: int = None
image_sizes: ListOrTuple(int) = None
channels: int = 3
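Since `UnetConfig` sets `extra = "allow"`, unrecognized u-net fields pass through validation untouched. A minimal sketch of validating a decoder config; the module path and field values are assumed for illustration:

```python
from dalle2_pytorch.train_configs import DecoderConfig  # assumed module path

decoder_config = DecoderConfig(
    unets = [dict(dim = 16, image_embed_dim = 512, dim_mults = (1, 2))],
    image_sizes = (64,),
    channels = 3
)

# pydantic coerces the plain dict into a UnetConfig during validation
print(decoder_config.unets[0].dim)   # 16
```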

dalle2_pytorch/trainer.py

@@ -288,7 +288,7 @@ class DiffusionPriorTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0]))
def save(self, path, overwrite = True, **kwargs):
def save(self, path, overwrite = True):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
@@ -298,8 +298,7 @@ class DiffusionPriorTrainer(nn.Module):
optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(),
version = get_pkg_version(),
step = self.step.item(),
**kwargs
step = self.step.item()
)
if self.use_ema:
@@ -320,7 +319,7 @@ class DiffusionPriorTrainer(nn.Module):
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return loaded_obj
return
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer'])
@@ -329,8 +328,6 @@ class DiffusionPriorTrainer(nn.Module):
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
return loaded_obj
def update(self):
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
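The two `save` variants differ in whether arbitrary keyword metadata is folded into the checkpoint dict alongside `model`, `optimizer`, and `step`. A usage sketch, reusing `diffusion_prior` from the prior sketch above (the path and the extra `epoch` field are illustrative):

```python
from dalle2_pytorch import DiffusionPriorTrainer  # assumed package-level export

trainer = DiffusionPriorTrainer(diffusion_prior, lr = 3e-4)

# with the **kwargs variant of save, extra metadata rides along in the checkpoint
trainer.save('./prior.pt', overwrite = True, epoch = 5)

# load restores scaler and optimizer state too, unless only_model is set
loaded_obj = trainer.load('./prior.pt')
```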
@@ -452,7 +449,7 @@ class DecoderTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0.]))
def save(self, path, overwrite = True, **kwargs):
def save(self, path, overwrite = True):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
@@ -460,8 +457,7 @@ class DecoderTrainer(nn.Module):
save_obj = dict(
model = self.decoder.state_dict(),
version = get_pkg_version(),
step = self.step.item(),
**kwargs
step = self.step.item()
)
for ind in range(0, self.num_unets):
@@ -489,7 +485,7 @@ class DecoderTrainer(nn.Module):
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return loaded_obj
return
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
@@ -504,8 +500,6 @@ class DecoderTrainer(nn.Module):
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
return loaded_obj
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
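DecoderTrainer keeps one gradient scaler per u-net (the `scaler{ind}` keys above) and mirrors each u-net with an EMA copy, so training steps are addressed by `unet_number`. A usage sketch with illustrative values, reusing `decoder`, `images`, and `image_embed` from the decoder sketch above:

```python
from dalle2_pytorch import DecoderTrainer  # assumed package-level export

decoder_trainer = DecoderTrainer(decoder, lr = 3e-4)

# kwargs such as image_embed are assumed to pass through to the decoder
loss = decoder_trainer(images, image_embed = image_embed, unet_number = 1)
decoder_trainer.update(unet_number = 1)   # steps the optimizer/scaler for that u-net

decoder_trainer.save('./decoder.pt')
decoder_trainer.load('./decoder.pt')

ema_unets = decoder_trainer.unets   # EMA copies, per the property at the end of the hunk
```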

setup.py

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.4.14',
version = '0.4.9',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',