Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-23 16:14:26 +01:00)

Compare commits

12 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | b8af2210df |  |
|  | f4fe6c570d |  |
|  | 645e207441 |  |
|  | 00743b3a0b |  |
|  | 01589aff6a |  |
|  | 7ecfd76cc0 |  |
|  | 6161b61c55 |  |
|  | 1ed0f9d80b |  |
|  | f326a95e26 |  |
|  | d7a0a2ce4b |  |
|  | f23fab7ef7 |  |
|  | 857b9fbf1e |  |
README.md
@@ -24,6 +24,8 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
 *ongoing at 21k steps*
 
+- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
+
 ## Pre-Trained Models
 
 - LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
 
 - Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
@@ -1048,6 +1050,7 @@ This library would not have gotten to this working state without the help of
 - <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
 
 - <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
 
 - <a href="https://github.com/crowsonkb">Katherine</a> for her advice
 
+- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
 
 ... and many others. Thank you! 🙏
@@ -1140,7 +1143,7 @@ This library would not have gotten to this working state without the help of
 ```bibtex
 @inproceedings{Tu2022MaxViTMV,
     title = {MaxViT: Multi-Axis Vision Transformer},
-    author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
+    author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
     year = {2022}
 }
 ```
@@ -1196,7 +1199,7 @@ This library would not have gotten to this working state without the help of
 ```
 
 ```bibtex
-@misc{Saharia2022,
+@misc{Saharia2022,https://stability.ai/
     title = {Imagen: unprecedented photorealism × deep level of language understanding},
     author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
     year = {2022}
configs/train_prior_config.example.json (new file, 70 lines)
@@ -0,0 +1,70 @@
{
    "prior": {
        "clip": {
            "make": "x-clip",
            "model": "ViT-L/14",
            "base_model_kwargs": {
                "dim_text": 768,
                "dim_image": 768,
                "dim_latent": 768
            }
        },
        "net": {
            "dim": 768,
            "depth": 12,
            "num_timesteps": 1000,
            "num_time_embeds": 1,
            "num_image_embeds": 1,
            "num_text_embeds": 1,
            "dim_head": 64,
            "heads": 12,
            "ff_mult": 4,
            "norm_out": true,
            "attn_dropout": 0.0,
            "ff_dropout": 0.0,
            "final_proj": true,
            "normformer": true,
            "rotary_emb": true
        },
        "image_embed_dim": 768,
        "image_size": 224,
        "image_channels": 3,
        "timesteps": 1000,
        "cond_drop_prob": 0.1,
        "loss_type": "l2",
        "predict_x_start": true,
        "beta_schedule": "cosine",
        "condition_on_text_encodings": true
    },
    "data": {
        "image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
        "text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
        "meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
        "batch_size": 256,
        "splits": {
            "train": 0.9,
            "val": 1e-7,
            "test": 0.0999999
        }
    },
    "train": {
        "epochs": 1,
        "lr": 1.1e-4,
        "wd": 6.02e-2,
        "max_grad_norm": 0.5,
        "use_ema": true,
        "amp": false,
        "save_every": 10000
    },
    "load": {
        "source": null,
        "resume": false
    },
    "tracker": {
        "tracker_type": "wandb",
        "data_path": "./prior_checkpoints",
        "wandb_entity": "laion",
        "wandb_project": "diffusion-prior",
        "verbose": true
    }
}
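The pydantic classes added to train_configs.py further down consume this file. A minimal sketch of loading it, assuming the module path and using the example path above; `from_json_path` and `create` are both defined in this changeset:

```python
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig

# parse and validate the JSON against the pydantic models below
config = TrainDiffusionPriorConfig.from_json_path('configs/train_prior_config.example.json')

# build the DiffusionPrior (and its DiffusionPriorNetwork plus x-clip adapter) from the "prior" section
prior = config.prior.create()
```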
dalle2_pytorch/dalle2_pytorch.py
@@ -1107,13 +1107,20 @@ class Block(nn.Module):
         groups = 8
     ):
         super().__init__()
-        self.block = nn.Sequential(
-            nn.Conv2d(dim, dim_out, 3, padding = 1),
-            nn.GroupNorm(groups, dim_out),
-            nn.SiLU()
-        )
+        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
+        self.norm = nn.GroupNorm(groups, dim_out)
+        self.act = nn.SiLU()
 
-    def forward(self, x):
-        return self.block(x)
+    def forward(self, x, scale_shift = None):
+        x = self.project(x)
+        x = self.norm(x)
+
+        if exists(scale_shift):
+            scale, shift = scale_shift
+            x = x * (scale + 1) + shift
+
+        x = self.act(x)
+        return x
 
 class ResnetBlock(nn.Module):
     def __init__(
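Together with the `ResnetBlock` changes below, this replaces plain additive time conditioning with scale-shift modulation (FiLM-style, as in the adaptive group norm of improved DDPM): the normalized features are modulated by `x * (scale + 1) + shift` before the activation. A toy sketch of the new call pattern, assuming `Block` is importable from the module and using illustrative shapes:

```python
import torch
from dalle2_pytorch.dalle2_pytorch import Block  # assumed import path

block = Block(64, 64)

x = torch.randn(2, 64, 32, 32)
scale = torch.randn(2, 64, 1, 1)
shift = torch.randn(2, 64, 1, 1)

out = block(x, scale_shift = (scale, shift))  # conv -> groupnorm -> modulate -> silu
print(out.shape)                              # torch.Size([2, 64, 32, 32])
```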
@@ -1132,7 +1139,7 @@ class ResnetBlock(nn.Module):
         if exists(time_cond_dim):
             self.time_mlp = nn.Sequential(
                 nn.SiLU(),
-                nn.Linear(time_cond_dim, dim_out)
+                nn.Linear(time_cond_dim, dim_out * 2)
             )
 
         self.cross_attn = None
@@ -1152,11 +1159,14 @@ class ResnetBlock(nn.Module):
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
     def forward(self, x, cond = None, time_emb = None):
-        h = self.block1(x)
 
+        scale_shift = None
         if exists(self.time_mlp) and exists(time_emb):
             time_emb = self.time_mlp(time_emb)
-            h = rearrange(time_emb, 'b c -> b c 1 1') + h
+            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
+            scale_shift = time_emb.chunk(2, dim = 1)
+
+        h = self.block1(x, scale_shift = scale_shift)
 
         if exists(self.cross_attn):
             assert exists(cond)
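The `dim_out * 2` projection above is what feeds this: half of the time MLP's output channels become the scale and half the shift. A toy sketch with `dim_out = 64`:

```python
import torch
from einops import rearrange

time_emb = torch.randn(2, 128)                    # (batch, dim_out * 2) out of the time MLP
time_emb = rearrange(time_emb, 'b c -> b c 1 1')  # broadcastable over the feature map
scale_shift = time_emb.chunk(2, dim = 1)          # two (2, 64, 1, 1) tensors for Block's scale_shift
```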
@@ -1336,6 +1346,7 @@ class Unet(nn.Module):
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
+        num_resnet_blocks = 1,
         init_cross_embed_kernel_sizes = (3, 7, 15),
         cross_embed_downsample = False,
         cross_embed_downsample_kernel_sizes = (2, 4),
@@ -1421,6 +1432,7 @@ class Unet(nn.Module):
         # resnet block klass
 
         resnet_groups = cast_tuple(resnet_groups, len(in_out))
+        num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
 
         assert len(resnet_groups) == len(in_out)
 
@@ -1436,7 +1448,7 @@ class Unet(nn.Module):
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
-        for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
             is_first = ind == 0
             is_last = ind >= (num_resolutions - 1)
             layer_cond_dim = cond_dim if not is_first else None
@@ -1444,7 +1456,7 @@ class Unet(nn.Module):
             self.downs.append(nn.ModuleList([
                 ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 downsample_klass(dim_out) if not is_last else nn.Identity()
             ]))
 
@@ -1454,14 +1466,14 @@ class Unet(nn.Module):
         self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
 
-        for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
             is_last = ind >= (num_resolutions - 2)
             layer_cond_dim = cond_dim if not is_last else None
 
             self.ups.append(nn.ModuleList([
                 ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 Upsample(dim_in)
             ]))
 
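The new `num_resnet_blocks` kwarg goes through the same `cast_tuple` broadcasting as `resnet_groups`, so an int applies to every resolution and a tuple sets a per-stage count. A plausible sketch of that helper's contract (the repo's actual definition may differ slightly):

```python
def cast_tuple(val, length = 1):
    # broadcast a scalar to a tuple of the given length; pass tuples through unchanged
    return val if isinstance(val, tuple) else ((val,) * length)

cast_tuple(2, 4)             # (2, 2, 2, 2)
cast_tuple((1, 2, 2, 4), 4)  # (1, 2, 2, 4)
```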
@@ -1618,10 +1630,13 @@ class Unet(nn.Module):
 
         hiddens = []
 
-        for block1, sparse_attn, block2, downsample in self.downs:
-            x = block1(x, c, t)
+        for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
+            x = init_block(x, c, t)
             x = sparse_attn(x)
-            x = block2(x, c, t)
+
+            for resnet_block in resnet_blocks:
+                x = resnet_block(x, c, t)
+
             hiddens.append(x)
             x = downsample(x)
 
@@ -1632,11 +1647,14 @@ class Unet(nn.Module):
 
         x = self.mid_block2(x, mid_c, t)
 
-        for block1, sparse_attn, block2, upsample in self.ups:
+        for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
-            x = block1(x, c, t)
+            x = init_block(x, c, t)
             x = sparse_attn(x)
-            x = block2(x, c, t)
+
+            for resnet_block in resnet_blocks:
+                x = resnet_block(x, c, t)
+
             x = upsample(x)
 
         return self.final_conv(x)
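With the forward loops above unpacking the per-stage `nn.ModuleList`, each resolution can now stack several ResnetBlocks instead of exactly one. A hedged sketch of the new knob; the remaining Unet kwargs are elided or illustrative:

```python
unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (1, 2, 2, 4)  # one count per resolution; a bare int such as 2 is broadcast to all stages
)
```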
dalle2_pytorch/optimizer.py
@@ -12,6 +12,7 @@ def get_optimizer(
     betas = (0.9, 0.999),
     eps = 1e-8,
     filter_by_requires_grad = False,
+    group_wd_params = True,
     **kwargs
 ):
     if filter_by_requires_grad:
@@ -20,12 +21,12 @@ def get_optimizer(
     if wd == 0:
         return Adam(params, lr = lr, betas = betas, eps = eps)
 
-    params = set(params)
-    wd_params, no_wd_params = separate_weight_decayable_params(params)
+    if group_wd_params:
+        wd_params, no_wd_params = separate_weight_decayable_params(params)
 
-    param_groups = [
-        {'params': list(wd_params)},
-        {'params': list(no_wd_params), 'weight_decay': 0},
-    ]
+        params = [
+            {'params': list(wd_params)},
+            {'params': list(no_wd_params), 'weight_decay': 0},
+        ]
 
-    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
+    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
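The new `group_wd_params` flag keeps weight decay away from parameters that conventionally should not be regularized (biases, norm gains). A plausible sketch of the helper it relies on; the repo's actual implementation may differ:

```python
def separate_weight_decayable_params(params):
    # parameters with fewer than 2 dimensions (biases, LayerNorm/GroupNorm weights)
    # are typically excluded from weight decay
    no_wd_params = set(param for param in params if param.ndim < 2)
    wd_params = set(params) - no_wd_params
    return wd_params, no_wd_params
```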
dalle2_pytorch/train_configs.py
@@ -3,7 +3,18 @@ from torchvision import transforms as T
 from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
 
-from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork
+from x_clip import CLIP as XCLIP
+from coca_pytorch import CoCa
+
+from dalle2_pytorch.dalle2_pytorch import (
+    CoCaAdapter,
+    OpenAIClipAdapter,
+    Unet,
+    Decoder,
+    DiffusionPrior,
+    DiffusionPriorNetwork,
+    XClipAdapter,
+)
 
 # helper functions
 
@@ -16,7 +27,47 @@ def default(val, d):
 def ListOrTuple(inner_type):
     return Union[List[inner_type], Tuple[inner_type]]
 
-# pydantic classes
+def SingularOrIterable(inner_type):
+    return Union[inner_type, ListOrTuple(inner_type)]
+
+# general pydantic classes
+
+class TrainSplitConfig(BaseModel):
+    train: float = 0.75
+    val: float = 0.15
+    test: float = 0.1
+
+    @root_validator
+    def validate_all(cls, fields):
+        actual_sum = sum([*fields.values()])
+        if actual_sum != 1.:
+            raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
+        return fields
+
+class TrackerConfig(BaseModel):
+    tracker_type: str = 'console' # Decoder currently supports console and wandb
+    data_path: str = './models' # The path where files will be saved locally
+    init_config: Dict[str, Any] = None
+    wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
+    wandb_project: str = ''
+    verbose: bool = False # Whether to print console logging for non-console trackers
+
+# diffusion prior pydantic classes
+
+class AdapterConfig(BaseModel):
+    make: str = "openai"
+    model: str = "ViT-L/14"
+    base_model_kwargs: Dict[str, Any] = None
+
+    def create(self):
+        if self.make == "openai":
+            return OpenAIClipAdapter(self.model)
+        elif self.make == "x-clip":
+            return XClipAdapter(XCLIP(**self.base_model_kwargs))
+        elif self.make == "coca":
+            return CoCaAdapter(CoCa(**self.base_model_kwargs))
+        else:
+            raise AttributeError("No adapter with that name is available.")
 
 class DiffusionPriorNetworkConfig(BaseModel):
     dim: int
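`AdapterConfig.create()` is what turns the `"clip"` block of the example JSON into a model. A sketch mirroring that block, assuming `x_clip`'s `CLIP` accepts these kwargs and has defaults for the rest:

```python
from dalle2_pytorch.train_configs import AdapterConfig

adapter_config = AdapterConfig(
    make = "x-clip",
    model = "ViT-L/14",
    base_model_kwargs = dict(dim_text = 768, dim_image = 768, dim_latent = 768)
)

clip = adapter_config.create()  # instantiates x_clip.CLIP and wraps it in an XClipAdapter
```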
@@ -35,8 +86,12 @@ class DiffusionPriorNetworkConfig(BaseModel):
     normformer: bool = False
     rotary_emb: bool = True
 
+    def create(self):
+        kwargs = self.dict()
+        return DiffusionPriorNetwork(**kwargs)
+
 class DiffusionPriorConfig(BaseModel):
-    # only clip-less diffusion prior config for now
+    clip: AdapterConfig = None
     net: DiffusionPriorNetworkConfig
     image_embed_dim: int
     image_size: int
@@ -46,15 +101,59 @@ class DiffusionPriorConfig(BaseModel):
     loss_type: str = 'l2'
     predict_x_start: bool = True
     beta_schedule: str = 'cosine'
+    condition_on_text_encodings: bool = True
 
-    def create(self):
-        kwargs = self.dict()
-        diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net'))
-        return DiffusionPrior(net = diffusion_prior_network, **kwargs)
-
     class Config:
         extra = "allow"
 
+    def create(self):
+        kwargs = self.dict()
+
+        has_clip = exists(kwargs.pop('clip'))
+        kwargs.pop('net')
+
+        clip = None
+        if has_clip:
+            clip = self.clip.create()
+
+        diffusion_prior_network = self.net.create()
+        return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
+
+class DiffusionPriorTrainConfig(BaseModel):
+    epochs: int = 1
+    lr: float = 1.1e-4
+    wd: float = 6.02e-2
+    max_grad_norm: float = 0.5
+    use_ema: bool = True
+    ema_beta: float = 0.99
+    amp: bool = False
+    save_every: int = 10000 # what steps to save on
+
+class DiffusionPriorDataConfig(BaseModel):
+    image_url: str # path to embeddings folder
+    meta_url: str # path to metadata (captions) for images
+    splits: TrainSplitConfig
+    batch_size: int = 64
+
+class DiffusionPriorLoadConfig(BaseModel):
+    source: str = None
+    resume: bool = False
+
+class TrainDiffusionPriorConfig(BaseModel):
+    prior: DiffusionPriorConfig
+    data: DiffusionPriorDataConfig
+    train: DiffusionPriorTrainConfig
+    load: DiffusionPriorLoadConfig
+    tracker: TrackerConfig
+
+    @classmethod
+    def from_json_path(cls, json_path):
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls(**config)
+
+# decoder pydantic classes
+
 class UnetConfig(BaseModel):
     dim: int
     dim_mults: ListOrTuple(int)
@@ -94,17 +193,6 @@ class DecoderConfig(BaseModel):
     class Config:
         extra = "allow"
 
-class TrainSplitConfig(BaseModel):
-    train: float = 0.75
-    val: float = 0.15
-    test: float = 0.1
-
-    @root_validator
-    def validate_all(cls, fields):
-        if sum([*fields.values()]) != 1.:
-            raise ValueError(f'{fields.keys()} must sum to 1.0')
-        return fields
-
 class DecoderDataConfig(BaseModel):
     webdataset_base_url: str # path to a webdataset with jpg images
     embeddings_url: str # path to .npy files with embeddings
|
|||||||
|
|
||||||
class DecoderTrainConfig(BaseModel):
|
class DecoderTrainConfig(BaseModel):
|
||||||
epochs: int = 20
|
epochs: int = 20
|
||||||
lr: float = 1e-4
|
lr: SingularOrIterable(float) = 1e-4
|
||||||
wd: float = 0.01
|
wd: SingularOrIterable(float) = 0.01
|
||||||
max_grad_norm: float = 0.5
|
max_grad_norm: SingularOrIterable(float) = 0.5
|
||||||
save_every_n_samples: int = 100000
|
save_every_n_samples: int = 100000
|
||||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||||
device: str = 'cuda:0'
|
device: str = 'cuda:0'
|
||||||
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||||
validation_samples: int = None # Same as above but for validation.
|
validation_samples: int = None # Same as above but for validation.
|
||||||
use_ema: bool = True
|
use_ema: bool = True
|
||||||
ema_beta: float = 0.99
|
ema_beta: float = 0.999
|
||||||
amp: bool = False
|
amp: bool = False
|
||||||
save_all: bool = False # Whether to preserve all checkpoints
|
save_all: bool = False # Whether to preserve all checkpoints
|
||||||
save_latest: bool = True # Whether to always save the latest checkpoint
|
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||||
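Since the decoder can train several unets, `SingularOrIterable(float)` lets these hyperparameters be a single shared value or one value per unet. A sketch, with illustrative values:

```python
from dalle2_pytorch.train_configs import DecoderTrainConfig

shared = DecoderTrainConfig(lr = 1e-4)              # one learning rate for every unet
per_unet = DecoderTrainConfig(lr = (1e-4, 1.2e-4))  # one learning rate per unet
```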
@@ -160,14 +248,6 @@ class DecoderEvaluateConfig(BaseModel):
     KID: Dict[str, Any] = None
     LPIPS: Dict[str, Any] = None
 
-class TrackerConfig(BaseModel):
-    tracker_type: str = 'console' # Decoder currently supports console and wandb
-    data_path: str = './models' # The path where files will be saved locally
-    init_config: Dict[str, Any] = None
-    wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
-    wandb_project: str = ''
-    verbose: bool = False # Whether to print console logging for non-console trackers
-
 class DecoderLoadConfig(BaseModel):
     source: str = None # Supports file and wandb
     run_path: str = '' # Used only if source is wandb
dalle2_pytorch/trainer.py
@@ -254,6 +254,7 @@ class DiffusionPriorTrainer(nn.Module):
         eps = 1e-6,
         max_grad_norm = None,
         amp = False,
+        group_wd_params = True,
         **kwargs
     ):
         super().__init__()
@@ -279,6 +280,7 @@ class DiffusionPriorTrainer(nn.Module):
             lr = lr,
             wd = wd,
             eps = eps,
+            group_wd_params = group_wd_params,
             **kwargs
         )
 
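The flag is threaded through to `get_optimizer`, so parameter grouping can be disabled from the trainer. A hedged sketch, assuming a `diffusion_prior` built as in the earlier config example and the usual positional-first trainer signature:

```python
trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = 1.1e-4,
    wd = 6.02e-2,
    group_wd_params = False  # opt out of the wd / no-wd parameter grouping
)
```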
@@ -410,6 +412,7 @@ class DecoderTrainer(nn.Module):
         eps = 1e-8,
         max_grad_norm = 0.5,
         amp = False,
+        group_wd_params = True,
         **kwargs
     ):
         super().__init__()
@@ -435,6 +438,7 @@ class DecoderTrainer(nn.Module):
             lr = unet_lr,
             wd = unet_wd,
             eps = unet_eps,
+            group_wd_params = group_wd_params,
             **kwargs
         )
 
setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.5.0',
+  version = '0.5.6',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
train_decoder.py
@@ -347,7 +347,7 @@ def train(
         # Compute evaluation metrics
         if exists(evaluate_config):
             print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-            evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
+            evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
             tracker.log(evaluation, step=step, verbose=True)
 
         # Generate sample images
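This last change follows from the config refactor: `evaluate_config` is now a pydantic `DecoderEvaluateConfig` rather than a plain dict, and a `BaseModel` cannot be splatted with `**` directly, so `.dict()` converts it back into keyword arguments first.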