Compare commits

...

6 Commits

Author SHA1 Message Date
Phil Wang
a0e41267f8 just use an assert to make sure clip image channels is never different than the channels of the diffusion prior and decoder, if clip is given 2022-05-22 22:34:33 -07:00
Phil Wang
276abf337b fix and cleanup image size determination logic in decoder 2022-05-22 22:28:45 -07:00
Phil Wang
ae42d03006 allow for saving of additional fields on save method in trainers, and return loaded objects from the load method 2022-05-22 22:14:25 -07:00
Phil Wang
4d346e98d9 allow for config driven creation of clip-less diffusion prior 2022-05-22 20:36:20 -07:00
Phil Wang
2b1fd1ad2e product management 2022-05-22 19:23:40 -07:00
zion
82a2ef37d9 Update README.md (#109)
block in a section that links to available pre-trained models for those who are interested
2022-05-22 19:22:30 -07:00
5 changed files with 86 additions and 18 deletions

View File

@@ -24,6 +24,11 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
*ongoing at 21k steps*
## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder 🚧
- DALL-E 2 🚧
## Install
```bash
@@ -1079,6 +1084,7 @@ This library would not have gotten to this working state without the help of
- [x] use pydantic for config-driven training
- [x] for both diffusion prior and decoder, all exponential moving averaged models need to be saved and restored (as well as the step number)
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab

View File

@@ -890,6 +890,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
)
if exists(clip):
assert image_channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
@@ -1710,12 +1712,19 @@ class Decoder(BaseGaussianDiffusion):
)
self.unconditional = unconditional
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
# text conditioning
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
self.condition_on_text_encodings = condition_on_text_encodings
# clip
self.clip = None
if exists(clip):
assert not unconditional, 'clip must not be given if doing unconditional image training'
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
@@ -1725,13 +1734,20 @@ class Decoder(BaseGaussianDiffusion):
assert isinstance(clip, BaseClipAdapter)
self.clip = clip
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
else:
self.clip_image_size = default(image_size, lambda: image_sizes[-1])
self.channels = channels
self.condition_on_text_encodings = condition_on_text_encodings
# determine image size, with image_size and image_sizes taking precedence
if exists(image_size) or exists(image_sizes):
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
image_size = default(image_size, lambda: image_sizes[-1])
elif exists(clip):
image_size = clip.image_size
else:
raise Error('either image_size, image_sizes, or clip must be given to decoder')
# channels
self.channels = channels
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
@@ -1773,7 +1789,7 @@ class Decoder(BaseGaussianDiffusion):
# unet image sizes
image_sizes = default(image_sizes, (self.clip_image_size,))
image_sizes = default(image_sizes, (image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
@@ -1811,6 +1827,7 @@ class Decoder(BaseGaussianDiffusion):
self.clip_x_start = clip_x_start
# normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity

View File

@@ -3,7 +3,7 @@ from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork
# helper functions
@@ -18,6 +18,43 @@ def ListOrTuple(inner_type):
# pydantic classes
class DiffusionPriorNetworkConfig(BaseModel):
    # pydantic schema for the keyword arguments of DiffusionPriorNetwork
    # DiffusionPriorConfig.create unpacks the dict form of this model straight
    # into the DiffusionPriorNetwork constructor, so field names must match
    # that constructor's parameter names

    # required fields

    dim: int
    depth: int

    # None is an accepted value here, so the annotation says so explicitly -
    # a bare `int` with a None default only works where pydantic still infers
    # Optional from the default, and is flagged by static type checkers

    num_timesteps: Optional[int] = None

    # remaining fields all carry defaults

    num_time_embeds: int = 1
    num_image_embeds: int = 1
    num_text_embeds: int = 1
    dim_head: int = 64
    heads: int = 8
    ff_mult: int = 4
    norm_out: bool = True
    attn_dropout: float = 0.
    ff_dropout: float = 0.
    final_proj: bool = True
    normformer: bool = False
    rotary_emb: bool = True
class DiffusionPriorConfig(BaseModel):
    # config for building a DiffusionPrior without CLIP (the only variant supported for now)

    # nested config for the network wrapped by the prior
    net: DiffusionPriorNetworkConfig

    # forwarded verbatim as keyword arguments to DiffusionPrior
    image_embed_dim: int
    image_size: int
    image_channels: int = 3
    timesteps: int = 1000
    cond_drop_prob: float = 0.0
    loss_type: str = 'l2'
    predict_x_start: bool = True
    beta_schedule: str = 'cosine'

    class Config:
        # keep fields that are not declared above instead of rejecting them
        extra = "allow"

    def create(self):
        """Instantiate the network from the `net` entry, then the DiffusionPrior around it."""
        settings = self.dict()

        # `net` is removed from the settings so it is not passed twice below
        network_settings = settings.pop('net')
        network = DiffusionPriorNetwork(**network_settings)

        return DiffusionPrior(net = network, **settings)
class UnetConfig(BaseModel):
dim: int
dim_mults: ListOrTuple(int)
@@ -31,7 +68,7 @@ class UnetConfig(BaseModel):
extra = "allow"
class DecoderConfig(BaseModel):
unets: Union[List[UnetConfig], Tuple[UnetConfig]]
unets: ListOrTuple(UnetConfig)
image_size: int = None
image_sizes: ListOrTuple(int) = None
channels: int = 3
@@ -39,6 +76,8 @@ class DecoderConfig(BaseModel):
loss_type: str = 'l2'
beta_schedule: str = 'cosine'
learned_variance: bool = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
def create(self):
decoder_kwargs = self.dict()

View File

@@ -288,7 +288,7 @@ class DiffusionPriorTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0]))
def save(self, path, overwrite = True):
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
@@ -298,7 +298,8 @@ class DiffusionPriorTrainer(nn.Module):
optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(),
version = get_pkg_version(),
step = self.step.item()
step = self.step.item(),
**kwargs
)
if self.use_ema:
@@ -319,7 +320,7 @@ class DiffusionPriorTrainer(nn.Module):
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return
return loaded_obj
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer'])
@@ -328,6 +329,8 @@ class DiffusionPriorTrainer(nn.Module):
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
return loaded_obj
def update(self):
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
@@ -449,7 +452,7 @@ class DecoderTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0.]))
def save(self, path, overwrite = True):
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
@@ -457,7 +460,8 @@ class DecoderTrainer(nn.Module):
save_obj = dict(
model = self.decoder.state_dict(),
version = get_pkg_version(),
step = self.step.item()
step = self.step.item(),
**kwargs
)
for ind in range(0, self.num_unets):
@@ -485,7 +489,7 @@ class DecoderTrainer(nn.Module):
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return
return loaded_obj
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
@@ -500,6 +504,8 @@ class DecoderTrainer(nn.Module):
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
return loaded_obj
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.4.8',
version = '0.4.12',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',