mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1e7b5f6bb | ||
|
|
10b905b445 | ||
|
|
9b322ea634 | ||
|
|
ba64ea45cc | ||
|
|
64f7be1926 |
@@ -933,7 +933,7 @@ Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/r
|
||||
|
||||
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
|
||||
|
||||
## from dalle2_pytorch import load_diffusion_model, save_diffusion_model
|
||||
## from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
|
||||
|
||||
load_diffusion_model(dprior_path, device)
|
||||
|
||||
@@ -999,6 +999,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
|
||||
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
|
||||
- [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
|
||||
- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
@@ -1012,7 +1013,6 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder,load_diffusion_model,save_diffusion_model
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from functools import partial
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -34,42 +33,6 @@ from rotary_embedding_torch import RotaryEmbedding
|
||||
from x_clip import CLIP
|
||||
from coca_pytorch import CoCa
|
||||
|
||||
# Diffusion Prior model loading and saving functions
|
||||
|
||||
def load_diffusion_model(dprior_path, device ):
|
||||
|
||||
dprior_path = Path(dprior_path)
|
||||
assert dprior_path.exists(), 'Dprior model file does not exist'
|
||||
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
|
||||
|
||||
# Get hyperparameters of loaded model
|
||||
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
|
||||
dp_config = loaded_obj['hparams']['diffusion_prior']
|
||||
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
|
||||
|
||||
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
|
||||
|
||||
# Load state dict from saved model
|
||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||
|
||||
return diffusion_prior
|
||||
|
||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||
# Saving State Dict
|
||||
print("====================================== Saving checkpoint ======================================")
|
||||
state_dict = dict(model=model.state_dict(),
|
||||
optimizer=optimizer.state_dict(),
|
||||
scaler=scaler.state_dict(),
|
||||
hparams = config,
|
||||
image_embed_dim = {"image_embed_dim":image_embed_dim})
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
@@ -1288,8 +1251,7 @@ class Unet(nn.Module):
|
||||
cond_on_image_embeds = False,
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
block_type = 'resnet',
|
||||
block_resnet_groups = 8,
|
||||
resnet_groups = 8,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1367,7 +1329,9 @@ class Unet(nn.Module):
|
||||
|
||||
# resnet block klass
|
||||
|
||||
block_klass = partial(ResnetBlock, groups = block_resnet_groups)
|
||||
resnet_groups = cast_tuple(resnet_groups, len(in_out))
|
||||
|
||||
assert len(resnet_groups) == len(in_out)
|
||||
|
||||
# layers
|
||||
|
||||
@@ -1375,38 +1339,39 @@ class Unet(nn.Module):
|
||||
self.ups = nn.ModuleList([])
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
|
||||
is_first = ind == 0
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
|
||||
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
|
||||
is_last = ind >= (num_resolutions - 2)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
|
||||
out_dim = default(out_dim, channels)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
block_klass(dim, dim),
|
||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
||||
nn.Conv2d(dim, out_dim, 1)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
@@ -39,6 +40,50 @@ def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# print helpers
|
||||
|
||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||
flank = symbol * repeat
|
||||
return f'{flank} {s} {flank}'
|
||||
|
||||
# saving and loading functions
|
||||
|
||||
# for diffusion prior
|
||||
|
||||
def load_diffusion_model(dprior_path, device):
|
||||
dprior_path = Path(dprior_path)
|
||||
assert dprior_path.exists(), 'Dprior model file does not exist'
|
||||
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
|
||||
|
||||
# Get hyperparameters of loaded model
|
||||
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
|
||||
dp_config = loaded_obj['hparams']['diffusion_prior']
|
||||
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
|
||||
|
||||
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
|
||||
|
||||
# Load state dict from saved model
|
||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||
|
||||
return diffusion_prior
|
||||
|
||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||
# Saving State Dict
|
||||
print_ribbon('Saving checkpoint')
|
||||
|
||||
state_dict = dict(model=model.state_dict(),
|
||||
optimizer=optimizer.state_dict(),
|
||||
scaler=scaler.state_dict(),
|
||||
hparams = config,
|
||||
image_embed_dim = {"image_embed_dim":image_embed_dim})
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
|
||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.2.2',
|
||||
version = '0.2.5',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -6,7 +6,8 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from embedding_reader import EmbeddingReader
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, load_diffusion_model, save_diffusion_model
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
||||
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
from torch.cuda.amp import autocast,GradScaler
|
||||
|
||||
@@ -153,7 +154,7 @@ def train(image_embed_dim,
|
||||
os.makedirs(save_path)
|
||||
|
||||
# Get image and text embeddings from the servers
|
||||
print("==============Downloading embeddings - image and text====================")
|
||||
print_ribbon("Downloading embeddings - image and text")
|
||||
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
|
||||
num_data_points = text_reader.count
|
||||
@@ -341,7 +342,7 @@ def main():
|
||||
RESUME,
|
||||
DPRIOR_PATH,
|
||||
config,
|
||||
atgs.wandb_entity,
|
||||
args.wandb_entity,
|
||||
args.wandb_project,
|
||||
args.learning_rate,
|
||||
args.max_grad_norm,
|
||||
|
||||
Reference in New Issue
Block a user