mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-15 13:44:21 +01:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0e41267f8 | ||
|
|
276abf337b | ||
|
|
ae42d03006 | ||
|
|
4d346e98d9 | ||
|
|
2b1fd1ad2e | ||
|
|
82a2ef37d9 |
@@ -24,6 +24,11 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
|
|||||||
|
|
||||||
*ongoing at 21k steps*
|
*ongoing at 21k steps*
|
||||||
|
|
||||||
|
## Pre-Trained Models
|
||||||
|
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
|
||||||
|
- Decoder 🚧
|
||||||
|
- DALL-E 2 🚧
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -1079,6 +1084,7 @@ This library would not have gotten to this working state without the help of
|
|||||||
- [x] use pydantic for config drive training
|
- [x] use pydantic for config drive training
|
||||||
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||||
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||||
|
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
|
||||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||||
- [ ] train on a toy task, offer in colab
|
- [ ] train on a toy task, offer in colab
|
||||||
|
|||||||
@@ -890,6 +890,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if exists(clip):
|
if exists(clip):
|
||||||
|
assert image_channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||||
|
|
||||||
if isinstance(clip, CLIP):
|
if isinstance(clip, CLIP):
|
||||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||||
elif isinstance(clip, CoCa):
|
elif isinstance(clip, CoCa):
|
||||||
@@ -1710,12 +1712,19 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.unconditional = unconditional
|
self.unconditional = unconditional
|
||||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
|
||||||
|
|
||||||
assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
# text conditioning
|
||||||
|
|
||||||
|
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||||
|
self.condition_on_text_encodings = condition_on_text_encodings
|
||||||
|
|
||||||
|
# clip
|
||||||
|
|
||||||
self.clip = None
|
self.clip = None
|
||||||
if exists(clip):
|
if exists(clip):
|
||||||
|
assert not unconditional, 'clip must not be given if doing unconditional image training'
|
||||||
|
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||||
|
|
||||||
if isinstance(clip, CLIP):
|
if isinstance(clip, CLIP):
|
||||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||||
elif isinstance(clip, CoCa):
|
elif isinstance(clip, CoCa):
|
||||||
@@ -1725,13 +1734,20 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
assert isinstance(clip, BaseClipAdapter)
|
assert isinstance(clip, BaseClipAdapter)
|
||||||
|
|
||||||
self.clip = clip
|
self.clip = clip
|
||||||
self.clip_image_size = clip.image_size
|
|
||||||
self.channels = clip.image_channels
|
|
||||||
else:
|
|
||||||
self.clip_image_size = default(image_size, lambda: image_sizes[-1])
|
|
||||||
self.channels = channels
|
|
||||||
|
|
||||||
self.condition_on_text_encodings = condition_on_text_encodings
|
# determine image size, with image_size and image_sizes taking precedence
|
||||||
|
|
||||||
|
if exists(image_size) or exists(image_sizes):
|
||||||
|
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
|
||||||
|
image_size = default(image_size, lambda: image_sizes[-1])
|
||||||
|
elif exists(clip):
|
||||||
|
image_size = clip.image_size
|
||||||
|
else:
|
||||||
|
raise Error('either image_size, image_sizes, or clip must be given to decoder')
|
||||||
|
|
||||||
|
# channels
|
||||||
|
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
# automatically take care of ensuring that first unet is unconditional
|
# automatically take care of ensuring that first unet is unconditional
|
||||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||||
@@ -1773,7 +1789,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
# unet image sizes
|
# unet image sizes
|
||||||
|
|
||||||
image_sizes = default(image_sizes, (self.clip_image_size,))
|
image_sizes = default(image_sizes, (image_size,))
|
||||||
image_sizes = tuple(sorted(set(image_sizes)))
|
image_sizes = tuple(sorted(set(image_sizes)))
|
||||||
|
|
||||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||||
@@ -1811,6 +1827,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.clip_x_start = clip_x_start
|
self.clip_x_start = clip_x_start
|
||||||
|
|
||||||
# normalize and unnormalize image functions
|
# normalize and unnormalize image functions
|
||||||
|
|
||||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||||
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from torchvision import transforms as T
|
|||||||
from pydantic import BaseModel, validator, root_validator
|
from pydantic import BaseModel, validator, root_validator
|
||||||
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
||||||
|
|
||||||
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder
|
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork
|
||||||
|
|
||||||
# helper functions
|
# helper functions
|
||||||
|
|
||||||
@@ -18,6 +18,43 @@ def ListOrTuple(inner_type):
|
|||||||
|
|
||||||
# pydantic classes
|
# pydantic classes
|
||||||
|
|
||||||
|
class DiffusionPriorNetworkConfig(BaseModel):
|
||||||
|
dim: int
|
||||||
|
depth: int
|
||||||
|
num_timesteps: int = None
|
||||||
|
num_time_embeds: int = 1
|
||||||
|
num_image_embeds: int = 1
|
||||||
|
num_text_embeds: int = 1
|
||||||
|
dim_head: int = 64
|
||||||
|
heads: int = 8
|
||||||
|
ff_mult: int = 4
|
||||||
|
norm_out: bool = True
|
||||||
|
attn_dropout: float = 0.
|
||||||
|
ff_dropout: float = 0.
|
||||||
|
final_proj: bool = True
|
||||||
|
normformer: bool = False
|
||||||
|
rotary_emb: bool = True
|
||||||
|
|
||||||
|
class DiffusionPriorConfig(BaseModel):
|
||||||
|
# only clip-less diffusion prior config for now
|
||||||
|
net: DiffusionPriorNetworkConfig
|
||||||
|
image_embed_dim: int
|
||||||
|
image_size: int
|
||||||
|
image_channels: int = 3
|
||||||
|
timesteps: int = 1000
|
||||||
|
cond_drop_prob: float = 0.
|
||||||
|
loss_type: str = 'l2'
|
||||||
|
predict_x_start: bool = True
|
||||||
|
beta_schedule: str = 'cosine'
|
||||||
|
|
||||||
|
def create(self):
|
||||||
|
kwargs = self.dict()
|
||||||
|
diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net'))
|
||||||
|
return DiffusionPrior(net = diffusion_prior_network, **kwargs)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "allow"
|
||||||
|
|
||||||
class UnetConfig(BaseModel):
|
class UnetConfig(BaseModel):
|
||||||
dim: int
|
dim: int
|
||||||
dim_mults: ListOrTuple(int)
|
dim_mults: ListOrTuple(int)
|
||||||
@@ -31,7 +68,7 @@ class UnetConfig(BaseModel):
|
|||||||
extra = "allow"
|
extra = "allow"
|
||||||
|
|
||||||
class DecoderConfig(BaseModel):
|
class DecoderConfig(BaseModel):
|
||||||
unets: Union[List[UnetConfig], Tuple[UnetConfig]]
|
unets: ListOrTuple(UnetConfig)
|
||||||
image_size: int = None
|
image_size: int = None
|
||||||
image_sizes: ListOrTuple(int) = None
|
image_sizes: ListOrTuple(int) = None
|
||||||
channels: int = 3
|
channels: int = 3
|
||||||
@@ -39,6 +76,8 @@ class DecoderConfig(BaseModel):
|
|||||||
loss_type: str = 'l2'
|
loss_type: str = 'l2'
|
||||||
beta_schedule: str = 'cosine'
|
beta_schedule: str = 'cosine'
|
||||||
learned_variance: bool = True
|
learned_variance: bool = True
|
||||||
|
image_cond_drop_prob: float = 0.1
|
||||||
|
text_cond_drop_prob: float = 0.5
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
decoder_kwargs = self.dict()
|
decoder_kwargs = self.dict()
|
||||||
|
|||||||
@@ -288,7 +288,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
self.register_buffer('step', torch.tensor([0]))
|
self.register_buffer('step', torch.tensor([0]))
|
||||||
|
|
||||||
def save(self, path, overwrite = True):
|
def save(self, path, overwrite = True, **kwargs):
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
assert not (path.exists() and not overwrite)
|
assert not (path.exists() and not overwrite)
|
||||||
path.parent.mkdir(parents = True, exist_ok = True)
|
path.parent.mkdir(parents = True, exist_ok = True)
|
||||||
@@ -298,7 +298,8 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
optimizer = self.optimizer.state_dict(),
|
optimizer = self.optimizer.state_dict(),
|
||||||
model = self.diffusion_prior.state_dict(),
|
model = self.diffusion_prior.state_dict(),
|
||||||
version = get_pkg_version(),
|
version = get_pkg_version(),
|
||||||
step = self.step.item()
|
step = self.step.item(),
|
||||||
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
@@ -319,7 +320,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||||
|
|
||||||
if only_model:
|
if only_model:
|
||||||
return
|
return loaded_obj
|
||||||
|
|
||||||
self.scaler.load_state_dict(loaded_obj['scaler'])
|
self.scaler.load_state_dict(loaded_obj['scaler'])
|
||||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||||
@@ -328,6 +329,8 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
assert 'ema' in loaded_obj
|
assert 'ema' in loaded_obj
|
||||||
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||||
|
|
||||||
|
return loaded_obj
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
if exists(self.max_grad_norm):
|
if exists(self.max_grad_norm):
|
||||||
self.scaler.unscale_(self.optimizer)
|
self.scaler.unscale_(self.optimizer)
|
||||||
@@ -449,7 +452,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
self.register_buffer('step', torch.tensor([0.]))
|
self.register_buffer('step', torch.tensor([0.]))
|
||||||
|
|
||||||
def save(self, path, overwrite = True):
|
def save(self, path, overwrite = True, **kwargs):
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
assert not (path.exists() and not overwrite)
|
assert not (path.exists() and not overwrite)
|
||||||
path.parent.mkdir(parents = True, exist_ok = True)
|
path.parent.mkdir(parents = True, exist_ok = True)
|
||||||
@@ -457,7 +460,8 @@ class DecoderTrainer(nn.Module):
|
|||||||
save_obj = dict(
|
save_obj = dict(
|
||||||
model = self.decoder.state_dict(),
|
model = self.decoder.state_dict(),
|
||||||
version = get_pkg_version(),
|
version = get_pkg_version(),
|
||||||
step = self.step.item()
|
step = self.step.item(),
|
||||||
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
for ind in range(0, self.num_unets):
|
for ind in range(0, self.num_unets):
|
||||||
@@ -485,7 +489,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||||
|
|
||||||
if only_model:
|
if only_model:
|
||||||
return
|
return loaded_obj
|
||||||
|
|
||||||
for ind in range(0, self.num_unets):
|
for ind in range(0, self.num_unets):
|
||||||
scaler_key = f'scaler{ind}'
|
scaler_key = f'scaler{ind}'
|
||||||
@@ -500,6 +504,8 @@ class DecoderTrainer(nn.Module):
|
|||||||
assert 'ema' in loaded_obj
|
assert 'ema' in loaded_obj
|
||||||
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||||
|
|
||||||
|
return loaded_obj
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unets(self):
|
def unets(self):
|
||||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||||
|
|||||||
Reference in New Issue
Block a user