Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-23 13:54:29 +01:00)
Compare commits
6 Commits
| Author | SHA1 | Date |
|---|---|---|
| | c8422ffd5d | |
| | 2aadc23c7c | |
| | c098f57e09 | |
| | 0021535c26 | |
| | 56883910fb | |
| | 893f270012 | |
README.md (42 lines changed)
@@ -1017,33 +1017,6 @@ The most significant parameters for the script are as follows:
 - `clip`, default = `None` # Signals the prior to use pre-computed embeddings
 
-#### Loading and Saving the DiffusionPrior model
-
-Two methods are provided, load_diffusion_model and save_diffusion_model, with self-explanatory names.
-
-```python
-from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
-```
-
-##### Loading
-
-load_diffusion_model(dprior_path, device)
-dprior_path : path to the saved model (.pth)
-device : the cuda device you're running on
-
-##### Saving
-
-save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
-save_path : path to save at
-model : a DiffusionPrior object
-optimizer : optimizer object - see train_diffusion_prior.py for how to create one,
-  e.g. optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
-scaler : a GradScaler object,
-  e.g. scaler = GradScaler(enabled=amp)
-config : config object created in train_diffusion_prior.py - see that file for an example
-image_embed_dim : the dimension of the image embedding,
-  e.g. 768
-
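For reference, the removed section above documents the two checkpoint helpers that this change also deletes from dalle2_pytorch.train. A minimal usage sketch of that (pre-removal) API; the checkpoint path, weight decay, learning rate, and embedding dimension below are purely illustrative:

```python
import torch
from torch.cuda.amp import GradScaler

from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
from dalle2_pytorch.optimizer import get_optimizer

device = torch.device('cuda:0')

# load a previously saved prior checkpoint (path is illustrative)
diffusion_prior, loaded_obj = load_diffusion_model('./checkpoints/prior.pth', device)

# ... training loop goes here ...
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd = 1e-2, lr = 1.1e-4)
scaler = GradScaler(enabled = True)

# write a new timestamped checkpoint into ./checkpoints,
# reusing the hyperparameter config bundled with the loaded checkpoint
save_diffusion_model(
    './checkpoints',
    diffusion_prior,
    optimizer,
    scaler,
    loaded_obj['hparams'],
    768   # image_embed_dim
)
```

Since load_diffusion_model rebuilds the prior from the hyperparameters stored in the checkpoint, the loaded 'hparams' config can be passed straight back to save_diffusion_model, as done above.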
 ## CLI (wip)
 
 ```bash
@@ -1092,19 +1065,14 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
 - [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
 - [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
+- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
+- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
+- [x] allow for unet to be able to condition non-cross attention style as well
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
-- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
-- [ ] train on a toy task, offer in colab
-- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
-- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
+- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
 - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
-- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
-- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
-- [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
-- [ ] decoder needs one day worth of refactor for tech debt
-- [ ] allow for unet to be able to condition non-cross attention style as well
-- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
+- [ ] build infilling
 
 ## Citations
 
@@ -15,7 +15,7 @@
   "channels": 3,
   "timesteps": 1000,
   "loss_type": "l2",
-  "beta_schedule": "cosine",
+  "beta_schedule": ["cosine"],
   "learned_variance": true
 },
 "data": {
@@ -14,6 +14,8 @@ from dalle2_pytorch.optimizer import get_optimizer
 from dalle2_pytorch.version import __version__
 from packaging import version
 
+from ema_pytorch import EMA
+
 from accelerate import Accelerator
 
 import numpy as np
@@ -62,16 +64,6 @@ def num_to_groups(num, divisor):
         arr.append(remainder)
     return arr
 
-def clamp(value, min_value = None, max_value = None):
-    assert exists(min_value) or exists(max_value)
-
-    if exists(min_value):
-        value = max(value, min_value)
-
-    if exists(max_value):
-        value = min(value, max_value)
-
-    return value
-
 # decorators
 
 def cast_torch_tensor(fn):
@@ -145,146 +137,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
         chunk_size_frac = chunk_size / batch_size
         yield chunk_size_frac, (chunked_args, chunked_kwargs)
 
-# saving and loading functions
-
-# for diffusion prior
-
-def load_diffusion_model(dprior_path, device):
-    dprior_path = Path(dprior_path)
-    assert dprior_path.exists(), 'Dprior model file does not exist'
-    loaded_obj = torch.load(str(dprior_path), map_location='cpu')
-
-    # Get hyperparameters of loaded model
-    dpn_config = loaded_obj['hparams']['diffusion_prior_network']
-    dp_config = loaded_obj['hparams']['diffusion_prior']
-    image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
-
-    # Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
-
-    # DiffusionPriorNetwork
-    prior_network = DiffusionPriorNetwork(dim = image_embed_dim, **dpn_config).to(device)
-
-    # DiffusionPrior with text embeddings and image embeddings pre-computed
-    diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
-
-    # Load state dict from saved model
-    diffusion_prior.load_state_dict(loaded_obj['model'])
-
-    return diffusion_prior, loaded_obj
-
-def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
-    # Saving State Dict
-    print_ribbon('Saving checkpoint')
-
-    state_dict = dict(model = model.state_dict(),
-                      optimizer = optimizer.state_dict(),
-                      scaler = scaler.state_dict(),
-                      hparams = config,
-                      image_embed_dim = {"image_embed_dim": image_embed_dim})
-    torch.save(state_dict, save_path + '/' + str(time.time()) + '_saved_model.pth')
-
-# exponential moving average wrapper
-
-class EMA(nn.Module):
-    """
-    Implements exponential moving average shadowing for your model.
-
-    Utilizes an inverse decay schedule to manage longer term training runs.
-    By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
-
-    @crowsonkb's notes on EMA Warmup:
-
-    If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
-    good values for models you plan to train for a million or more steps (reaches decay
-    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
-    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
-    215.4k steps).
-
-    Args:
-        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-        power (float): Exponential factor of EMA warmup. Default: 1.
-        min_value (float): The minimum EMA decay rate. Default: 0.
-    """
-    def __init__(
-        self,
-        model,
-        beta = 0.9999,
-        update_after_step = 100,
-        update_every = 10,
-        inv_gamma = 1.0,
-        power = 2/3,
-        min_value = 0.0,
-    ):
-        super().__init__()
-        self.beta = beta
-        self.online_model = model
-        self.ema_model = copy.deepcopy(model)
-
-        self.update_every = update_every
-        self.update_after_step = update_after_step
-
-        self.inv_gamma = inv_gamma
-        self.power = power
-        self.min_value = min_value
-
-        self.register_buffer('initted', torch.Tensor([False]))
-        self.register_buffer('step', torch.tensor([0]))
-
-    def restore_ema_model_device(self):
-        device = self.initted.device
-        self.ema_model.to(device)
-
-    def copy_params_from_model_to_ema(self):
-        for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
-            ma_param.data.copy_(current_param.data)
-
-        for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
-            ma_buffer.data.copy_(current_buffer.data)
-
-    def get_current_decay(self):
-        epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
-        value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
-
-        if epoch <= 0:
-            return 0.
-
-        return clamp(value, min_value = self.min_value, max_value = self.beta)
-
-    def update(self):
-        step = self.step.item()
-        self.step += 1
-
-        if (step % self.update_every) != 0:
-            return
-
-        if step <= self.update_after_step:
-            self.copy_params_from_model_to_ema()
-            return
-
-        if not self.initted.item():
-            self.copy_params_from_model_to_ema()
-            self.initted.data.copy_(torch.Tensor([True]))
-
-        self.update_moving_average(self.ema_model, self.online_model)
-
-    @torch.no_grad()
-    def update_moving_average(self, ma_model, current_model):
-        current_decay = self.get_current_decay()
-
-        for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
-            difference = ma_params.data - current_params.data
-            difference.mul_(1.0 - current_decay)
-            ma_params.sub_(difference)
-
-        for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
-            difference = ma_buffer - current_buffer
-            difference.mul_(1.0 - current_decay)
-            ma_buffer.sub_(difference)
-
-    def __call__(self, *args, **kwargs):
-        return self.ema_model(*args, **kwargs)
-
 # diffusion prior trainer
 
 def prior_sample_in_chunks(fn):
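As a quick sanity check on the warmup schedule described in the removed EMA docstring above, here is a small standalone sketch of the decay computation; it ignores the update_after_step offset, so the numbers are approximate:

```python
# decay used at a given step, per the inverse-decay warmup in the removed EMA class:
#   value = 1 - (1 + step / inv_gamma) ** -power, clamped to [min_value, beta]
def ema_decay(step, inv_gamma = 1.0, power = 2 / 3, min_value = 0.0, beta = 0.9999):
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(max(value, min_value), beta)

# with the defaults above (power = 2/3), the decay reaches roughly 0.999
# around 31.6k steps and is clamped at beta = 0.9999 near 1M steps,
# matching the figures quoted in the docstring
print(round(ema_decay(31_600), 4))     # ~0.999
print(round(ema_decay(1_000_000), 4))  # ~0.9999
```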
@@ -505,26 +357,20 @@ class DiffusionPriorTrainer(nn.Module):
     @cast_torch_tensor
     @prior_sample_in_chunks
     def p_sample_loop(self, *args, **kwargs):
-        if self.use_ema:
-            return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
-        else:
-            return self.diffusion_prior.p_sample_loop(*args, **kwargs)
+        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+        return model.p_sample_loop(*args, **kwargs)
 
     @torch.no_grad()
     @cast_torch_tensor
     @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
-        if self.use_ema:
-            return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
-        else:
-            return self.diffusion_prior.sample(*args, **kwargs)
+        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+        return model.sample(*args, **kwargs)
 
     @torch.no_grad()
     def sample_batch_size(self, *args, **kwargs):
-        if self.use_ema:
-            return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
-        else:
-            return self.diffusion_prior.sample_batch_size(*args, **kwargs)
+        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+        return model.sample_batch_size(*args, **kwargs)
 
     @torch.no_grad()
     @cast_torch_tensor
@@ -1 +1 @@
-__version__ = '0.11.2'
+__version__ = '0.11.4'
@@ -16,10 +16,11 @@ from torchvision.utils import make_grid, save_image
 
 from einops import rearrange
 
-from dalle2_pytorch.train import EMA
 from dalle2_pytorch.vqgan_vae import VQGanVAE
 from dalle2_pytorch.optimizer import get_optimizer
 
+from ema_pytorch import EMA
+
 # helpers
 
 def exists(val):
@@ -97,7 +98,7 @@ class VQGanVAETrainer(nn.Module):
         valid_frac = 0.05,
         random_split_seed = 42,
         ema_beta = 0.995,
-        ema_update_after_step = 2000,
+        ema_update_after_step = 500,
         ema_update_every = 10,
         apply_grad_penalty_every = 4,
         amp = False