mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5943498cf2 | ||
|
|
1bd8a7835a | ||
|
|
f33453df9f | ||
|
|
1e4bb2bafb | ||
|
|
ee75515c7d | ||
|
|
ec68243479 | ||
|
|
3afdcdfe86 | ||
|
|
b9a908ff75 | ||
|
|
e1fe3089df | ||
|
|
6d477d7654 | ||
|
|
531fe4b62f | ||
|
|
ec5a77fc55 | ||
|
|
fac63c61bc | ||
|
|
3d23ba4aa5 | ||
|
|
282c35930f | ||
|
|
27b0f7ca0d | ||
|
|
7b0edf9e42 | ||
|
|
a922a539de |
15
README.md
15
README.md
@@ -20,18 +20,20 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
|
||||
|
||||
- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
|
||||
|
||||
<img src="./samples/oxford.png" width="600px" />
|
||||
<img src="./samples/oxford.png" width="450px" />
|
||||
|
||||
*ongoing at 21k steps*
|
||||
|
||||
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
|
||||
|
||||
- <a href="https://github.com/rom1504">Romain</a> has scaled up training to 800 GPUs with the available scripts without any issues
|
||||
|
||||
## Pre-Trained Models
|
||||
|
||||
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
|
||||
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
|
||||
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
|
||||
- DALL-E 2 🚧
|
||||
- DALL-E 2 🚧 - <a href="https://github.com/LAION-AI/dalle2-laion">DALL-E 2 Laion repository</a>
|
||||
|
||||
## Appreciation
|
||||
|
||||
@@ -1112,15 +1114,6 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||
year = {2022},
|
||||
url = {https://arxiv.org/abs/2204.01697}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Yu2021VectorquantizedIM,
|
||||
title = {Vector-quantized Image Modeling with Improved VQGAN},
|
||||
|
||||
@@ -91,21 +91,83 @@ Each metric can be enabled by setting its configuration. The configuration keys
|
||||
|
||||
**<ins>Tracker</ins>:**
|
||||
|
||||
Selects which tracker to use and configures it.
|
||||
Selects how the experiment will be tracked.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
|
||||
| `data_path` | No | `./models` | Where the tracker will store local data. |
|
||||
| `verbose` | No | `False` | Enables console logging for non-console trackers. |
|
||||
| `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. |
|
||||
| `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. |
|
||||
| `log` | Yes | N/A | Logging configuration. |
|
||||
| `load` | No | `None` | Checkpoint loading configuration. |
|
||||
| `save` | Yes | N/A | Checkpoint/Model saving configuration. |
|
||||
Tracking is split up into three sections:
|
||||
* Log: Where to save run metadata and image output. Options are `console` or `wandb`.
|
||||
* Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`.
|
||||
* Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`.
|
||||
|
||||
Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
|
||||
**Logging:**
|
||||
|
||||
**<ins>Load</ins>:**
|
||||
|
||||
Selects where to load a pretrained model from.
|
||||
If using `console` there is no further configuration than setting `log_type` to `console`.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `source` | No | `None` | Supports `file` or `wandb`. |
|
||||
| `resume` | No | `False` | If the tracker support resuming the run, resume it. |
|
||||
| `log_type` | Yes | N/A | Must be `console`. |
|
||||
|
||||
Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).
|
||||
If using `wandb`
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `log_type` | Yes | N/A | Must be `wandb`. |
|
||||
| `wandb_entity` | Yes | N/A | The wandb entity to log to. |
|
||||
| `wandb_project` | Yes | N/A | The wandb project save the run to. |
|
||||
| `wandb_run_name` | No | `None` | The wandb run name. |
|
||||
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
|
||||
| `wandb_resume` | No | `False` | Whether to resume an old run. |
|
||||
|
||||
**Loading:**
|
||||
|
||||
If using `local`
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `load_from` | Yes | N/A | Must be `local`. |
|
||||
| `file_path` | Yes | N/A | The path to the checkpoint file. |
|
||||
|
||||
If using `url`
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `load_from` | Yes | N/A | Must be `url`. |
|
||||
| `url` | Yes | N/A | The url of the checkpoint file. |
|
||||
|
||||
If using `wandb`
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `load_from` | Yes | N/A | Must be `wandb`. |
|
||||
| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. |
|
||||
| `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. |
|
||||
|
||||
**Saving:**
|
||||
Unlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a run.
|
||||
|
||||
All save locations have these configuration options
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
|
||||
| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
|
||||
| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
|
||||
| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
|
||||
|
||||
If using `local`
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `save_to` | Yes | N/A | Must be `local`. |
|
||||
|
||||
If using `huggingface`
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `save_to` | Yes | N/A | Must be `huggingface`. |
|
||||
| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
|
||||
| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
|
||||
| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
|
||||
|
||||
If using `wandb`
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `save_to` | Yes | N/A | Must be `wandb`. |
|
||||
| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. |
|
||||
|
||||
@@ -80,20 +80,32 @@
|
||||
}
|
||||
},
|
||||
"tracker": {
|
||||
"tracker_type": "console",
|
||||
"data_path": "./models",
|
||||
"overwrite_data_path": true,
|
||||
|
||||
"wandb_entity": "",
|
||||
"wandb_project": "",
|
||||
"log": {
|
||||
"log_type": "wandb",
|
||||
|
||||
"verbose": false
|
||||
},
|
||||
"load": {
|
||||
"source": null,
|
||||
"wandb_entity": "your_wandb",
|
||||
"wandb_project": "your_project",
|
||||
|
||||
"run_path": "",
|
||||
"file_path": "",
|
||||
"verbose": true
|
||||
},
|
||||
|
||||
"resume": false
|
||||
"load": {
|
||||
"load_from": null
|
||||
},
|
||||
|
||||
"save": [{
|
||||
"save_to": "wandb"
|
||||
}, {
|
||||
"save_to": "huggingface",
|
||||
"huggingface_repo": "Veldrovive/test_model",
|
||||
|
||||
"save_all": true,
|
||||
"save_latest": true,
|
||||
"save_best": true,
|
||||
|
||||
"save_type": "model"
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,11 +63,16 @@ def default(val, d):
|
||||
return val
|
||||
return d() if callable(d) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
def cast_tuple(val, length = None):
|
||||
if isinstance(val, list):
|
||||
val = tuple(val)
|
||||
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
|
||||
|
||||
if exists(length):
|
||||
assert len(out) == length
|
||||
|
||||
return out
|
||||
|
||||
def module_device(module):
|
||||
return next(module.parameters()).device
|
||||
@@ -330,6 +335,10 @@ def approx_standard_normal_cdf(x):
|
||||
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
||||
assert x.shape == means.shape == log_scales.shape
|
||||
|
||||
# attempting to correct nan gradients when learned variance is turned on
|
||||
# in the setting of deepspeed fp16
|
||||
eps = 1e-12 if x.dtype == torch.float32 else 1e-5
|
||||
|
||||
centered_x = x - means
|
||||
inv_stdv = torch.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_x + 1. / 255.)
|
||||
@@ -344,7 +353,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
||||
log_cdf_plus,
|
||||
torch.where(x > thres,
|
||||
log_one_minus_cdf_min,
|
||||
log(cdf_delta)))
|
||||
log(cdf_delta, eps = eps)))
|
||||
|
||||
return log_probs
|
||||
|
||||
@@ -485,14 +494,16 @@ class NoiseScheduler(nn.Module):
|
||||
# diffusion prior
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
self.register_buffer("beta", torch.zeros(dim))
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
||||
|
||||
x = x / x.amax(dim = -1, keepdim = True).detach()
|
||||
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = -1, keepdim = True)
|
||||
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
@@ -501,10 +512,10 @@ class ChanLayerNorm(nn.Module):
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = x / x.amax(dim = 1, keepdim = True).detach()
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g
|
||||
|
||||
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
@@ -624,10 +635,13 @@ class Attention(nn.Module):
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
causal = False,
|
||||
rotary_emb = None
|
||||
rotary_emb = None,
|
||||
pb_relax_alpha = 32 ** 2
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
self.pb_relax_alpha = pb_relax_alpha
|
||||
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
|
||||
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
@@ -691,7 +705,10 @@ class Attention(nn.Module):
|
||||
|
||||
# attention
|
||||
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
sim = sim * self.pb_relax_alpha
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
@@ -1093,7 +1110,11 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
# decoder
|
||||
|
||||
def Upsample(dim, dim_out = None):
|
||||
def ConvTransposeUpsample(dim, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
|
||||
|
||||
def NearestUpsample(dim, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
return nn.Sequential(
|
||||
nn.Upsample(scale_factor = 2, mode = 'nearest'),
|
||||
@@ -1110,11 +1131,12 @@ class SinusoidalPosEmb(nn.Module):
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
dtype, device = x.dtype, x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
|
||||
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
|
||||
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
|
||||
return torch.cat((emb.sin(), emb.cos()), dim = -1)
|
||||
return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
@@ -1201,10 +1223,12 @@ class CrossAttention(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
norm_context = False
|
||||
norm_context = False,
|
||||
pb_relax_alpha = 32 ** 2
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
self.pb_relax_alpha = pb_relax_alpha
|
||||
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
@@ -1250,26 +1274,15 @@ class CrossAttention(nn.Module):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
sim = sim.masked_fill(~mask, max_neg_value)
|
||||
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
sim = sim * self.pb_relax_alpha
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class GridAttention(nn.Module):
|
||||
def __init__(self, *args, window_size = 8, **kwargs):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.attn = Attention(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
h, w = x.shape[-2:]
|
||||
wsz = self.window_size
|
||||
x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
|
||||
out = self.attn(x)
|
||||
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
|
||||
return out
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1351,6 +1364,7 @@ class Unet(nn.Module):
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
channels = 3,
|
||||
channels_out = None,
|
||||
self_attn = False,
|
||||
attn_dim_head = 32,
|
||||
attn_heads = 16,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
@@ -1369,6 +1383,8 @@ class Unet(nn.Module):
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
memory_efficient = False,
|
||||
scale_skip_connection = False,
|
||||
nearest_upsample = False,
|
||||
final_conv_kernel_size = 1,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1395,6 +1411,8 @@ class Unet(nn.Module):
|
||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
|
||||
num_stages = len(in_out)
|
||||
|
||||
# time, image embeddings, and optional text encoding
|
||||
|
||||
cond_dim = default(cond_dim, dim)
|
||||
@@ -1458,14 +1476,16 @@ class Unet(nn.Module):
|
||||
|
||||
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
|
||||
|
||||
self_attn = cast_tuple(self_attn, num_stages)
|
||||
|
||||
create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
|
||||
|
||||
# resnet block klass
|
||||
|
||||
resnet_groups = cast_tuple(resnet_groups, len(in_out))
|
||||
resnet_groups = cast_tuple(resnet_groups, num_stages)
|
||||
top_level_resnet_group = first(resnet_groups)
|
||||
|
||||
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
|
||||
|
||||
assert len(resnet_groups) == len(in_out)
|
||||
num_resnet_blocks = cast_tuple(num_resnet_blocks, num_stages)
|
||||
|
||||
# downsample klass
|
||||
|
||||
@@ -1473,6 +1493,10 @@ class Unet(nn.Module):
|
||||
if cross_embed_downsample:
|
||||
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
|
||||
|
||||
# upsample klass
|
||||
|
||||
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
|
||||
|
||||
# give memory efficient unet an initial resnet block
|
||||
|
||||
self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
|
||||
@@ -1483,9 +1507,9 @@ class Unet(nn.Module):
|
||||
self.ups = nn.ModuleList([])
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
skip_connect_dims = [] # keeping track of skip connection dimensions
|
||||
skip_connect_dims = [] # keeping track of skip connection dimensions
|
||||
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
|
||||
is_first = ind == 0
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
@@ -1493,35 +1517,47 @@ class Unet(nn.Module):
|
||||
dim_layer = dim_out if memory_efficient else dim_in
|
||||
skip_connect_dims.append(dim_layer)
|
||||
|
||||
attention = nn.Identity()
|
||||
if layer_self_attn:
|
||||
attention = create_self_attn(dim_layer)
|
||||
elif sparse_attn:
|
||||
attention = Residual(LinearAttention(dim_layer, **attn_kwargs))
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
|
||||
ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_layer, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
attention,
|
||||
downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_attn = create_self_attn(mid_dim)
|
||||
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
skip_connect_dim = skip_connect_dims.pop()
|
||||
|
||||
attention = nn.Identity()
|
||||
if layer_self_attn:
|
||||
attention = create_self_attn(dim_out)
|
||||
elif sparse_attn:
|
||||
attention = Residual(LinearAttention(dim_out, **attn_kwargs))
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
Upsample(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
|
||||
attention,
|
||||
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
|
||||
]))
|
||||
|
||||
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
|
||||
self.to_out = nn.Conv2d(dim, self.channels_out, 3, padding = 1)
|
||||
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
# for cascading DDPM, then reinit the unet with the right settings
|
||||
@@ -1595,6 +1631,7 @@ class Unet(nn.Module):
|
||||
|
||||
# time conditioning
|
||||
|
||||
time = time.type_as(x)
|
||||
time_hiddens = self.to_time_hiddens(time)
|
||||
|
||||
time_tokens = self.to_time_tokens(time_hiddens)
|
||||
@@ -1694,18 +1731,19 @@ class Unet(nn.Module):
|
||||
|
||||
hiddens = []
|
||||
|
||||
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
|
||||
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
|
||||
if exists(pre_downsample):
|
||||
x = pre_downsample(x)
|
||||
|
||||
x = init_block(x, t, c)
|
||||
x = sparse_attn(x)
|
||||
hiddens.append(x)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, t, c)
|
||||
hiddens.append(x)
|
||||
|
||||
x = attn(x)
|
||||
hiddens.append(x)
|
||||
|
||||
if exists(post_downsample):
|
||||
x = post_downsample(x)
|
||||
|
||||
@@ -1716,17 +1754,17 @@ class Unet(nn.Module):
|
||||
|
||||
x = self.mid_block2(x, t, mid_c)
|
||||
|
||||
connect_skip = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)
|
||||
connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
|
||||
|
||||
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
||||
for init_block, resnet_blocks, attn, upsample in self.ups:
|
||||
x = connect_skip(x)
|
||||
x = init_block(x, t, c)
|
||||
x = sparse_attn(x)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = connect_skip(x)
|
||||
x = resnet_block(x, t, c)
|
||||
|
||||
x = attn(x)
|
||||
x = upsample(x)
|
||||
|
||||
x = torch.cat((x, r), dim = 1)
|
||||
@@ -2229,7 +2267,8 @@ class Decoder(nn.Module):
|
||||
image_embed = None,
|
||||
text_encodings = None,
|
||||
text_mask = None,
|
||||
unet_number = None
|
||||
unet_number = None,
|
||||
return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
|
||||
):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
@@ -2279,7 +2318,12 @@ class Decoder(nn.Module):
|
||||
image = vae.encode(image)
|
||||
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
||||
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
||||
|
||||
if not return_lowres_cond_image:
|
||||
return losses
|
||||
|
||||
return losses, lowres_cond_img
|
||||
|
||||
# main class
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import urllib.request
|
||||
import os
|
||||
from pathlib import Path
|
||||
import importlib
|
||||
import shutil
|
||||
from itertools import zip_longest
|
||||
from typing import Optional, List, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from dalle2_pytorch.utils import import_or_print_error
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
||||
|
||||
# constants
|
||||
|
||||
@@ -27,126 +30,484 @@ def load_wandb_file(run_path, file_path, **kwargs):
|
||||
def load_local_file(file_path, **kwargs):
|
||||
return file_path
|
||||
|
||||
# base class
|
||||
|
||||
class BaseTracker(nn.Module):
|
||||
def __init__(self, data_path = DEFAULT_DATA_PATH):
|
||||
super().__init__()
|
||||
class BaseLogger:
|
||||
"""
|
||||
An abstract class representing an object that can log data.
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
verbose (bool): Whether of not to always print logs to the console.
|
||||
"""
|
||||
def __init__(self, data_path: str, verbose: bool = False, **kwargs):
|
||||
self.data_path = Path(data_path)
|
||||
self.data_path.mkdir(parents = True, exist_ok = True)
|
||||
self.verbose = verbose
|
||||
|
||||
def init(self, config, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log_images(self, images, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def recall_state_dict(self, recall_source, *args, **kwargs):
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
"""
|
||||
Loads a state dict from any source.
|
||||
Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
|
||||
this should not be linked to any individual tracker.
|
||||
Initializes the logger.
|
||||
Errors if the logger is invalid.
|
||||
"""
|
||||
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
|
||||
if recall_source == 'wandb':
|
||||
return torch.load(load_wandb_file(*args, **kwargs))
|
||||
elif recall_source == 'local':
|
||||
return torch.load(load_local_file(*args, **kwargs))
|
||||
else:
|
||||
raise ValueError('`recall_source` must be one of `wandb` or `local`')
|
||||
|
||||
def save_file(self, file_path, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def recall_file(self, recall_source, *args, **kwargs):
|
||||
if recall_source == 'wandb':
|
||||
return load_wandb_file(*args, **kwargs)
|
||||
elif recall_source == 'local':
|
||||
return load_local_file(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError('`recall_source` must be one of `wandb` or `local`')
|
||||
def log(self, log, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
# Tracker that no-ops all calls except for recall
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
class DummyTracker(BaseTracker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def log_file(self, file_path, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def init(self, config, **kwargs):
|
||||
pass
|
||||
def log_error(self, error_string, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
pass
|
||||
class ConsoleLogger(BaseLogger):
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
print("Logging to console")
|
||||
|
||||
def log_images(self, images, **kwargs):
|
||||
pass
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
pass
|
||||
|
||||
def save_file(self, file_path, **kwargs):
|
||||
pass
|
||||
|
||||
# basic stdout class
|
||||
|
||||
class ConsoleTracker(BaseTracker):
|
||||
def init(self, **config):
|
||||
print(config)
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
def log(self, log, **kwargs) -> None:
|
||||
print(log)
|
||||
|
||||
def log_images(self, images, **kwargs): # noop for logging images
|
||||
pass
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
torch.save(state_dict, str(self.data_path / relative_path))
|
||||
|
||||
def save_file(self, file_path, **kwargs):
|
||||
# This is a no-op for local file systems since it is already saved locally
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
|
||||
pass
|
||||
|
||||
# basic wandb class
|
||||
def log_file(self, file_path, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
class WandbTracker(BaseTracker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
|
||||
def log_error(self, error_string, **kwargs) -> None:
|
||||
print(error_string)
|
||||
|
||||
class WandbLogger(BaseLogger):
|
||||
"""
|
||||
Logs to a wandb run.
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
wandb_entity (str): The wandb entity to log to.
|
||||
wandb_project (str): The wandb project to log to.
|
||||
wandb_run_id (str): The wandb run id to resume.
|
||||
wandb_run_name (str): The wandb run name to use.
|
||||
wandb_resume (bool): Whether to resume a wandb run.
|
||||
"""
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
wandb_entity: str,
|
||||
wandb_project: str,
|
||||
wandb_run_id: Optional[str] = None,
|
||||
wandb_run_name: Optional[str] = None,
|
||||
wandb_resume: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.entity = wandb_entity
|
||||
self.project = wandb_project
|
||||
self.run_id = wandb_run_id
|
||||
self.run_name = wandb_run_name
|
||||
self.resume = wandb_resume
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
|
||||
assert self.project is not None, "wandb_project must be specified for wandb logger"
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
# Initializes the wandb run
|
||||
init_object = {
|
||||
"entity": self.entity,
|
||||
"project": self.project,
|
||||
"config": {**full_config.dict(), **extra_config}
|
||||
}
|
||||
if self.run_name is not None:
|
||||
init_object['name'] = self.run_name
|
||||
if self.resume:
|
||||
assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
|
||||
if self.run_name is not None:
|
||||
print("You are renaming a run. I hope that is what you intended.")
|
||||
init_object['resume'] = 'must'
|
||||
init_object['id'] = self.run_id
|
||||
|
||||
def init(self, **config):
|
||||
self.wandb.init(**config)
|
||||
self.wandb.init(**init_object)
|
||||
print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")
|
||||
|
||||
def log(self, log, verbose=False, **kwargs):
|
||||
if verbose:
|
||||
def log(self, log, **kwargs) -> None:
|
||||
if self.verbose:
|
||||
print(log)
|
||||
self.wandb.log(log, **kwargs)
|
||||
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs):
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
|
||||
"""
|
||||
Takes a tensor of images and a list of captions and logs them to wandb.
|
||||
"""
|
||||
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
||||
self.log({ image_section: wandb_images }, **kwargs)
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
"""
|
||||
Saves a state_dict to disk and uploads it
|
||||
"""
|
||||
full_path = str(self.data_path / relative_path)
|
||||
torch.save(state_dict, full_path)
|
||||
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
|
||||
self.wandb.log({ image_section: wandb_images }, **kwargs)
|
||||
|
||||
def save_file(self, file_path, base_path=None, **kwargs):
|
||||
"""
|
||||
Uploads a file from disk to wandb
|
||||
"""
|
||||
def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
|
||||
if base_path is None:
|
||||
base_path = self.data_path
|
||||
# Then we take the basepath as the parent of the file_path
|
||||
base_path = Path(file_path).parent
|
||||
self.wandb.save(str(file_path), base_path = str(base_path))
|
||||
|
||||
def log_error(self, error_string, step=None, **kwargs) -> None:
|
||||
if self.verbose:
|
||||
print(error_string)
|
||||
self.wandb.log({"error": error_string, **kwargs}, step=step)
|
||||
|
||||
logger_type_map = {
|
||||
'console': ConsoleLogger,
|
||||
'wandb': WandbLogger,
|
||||
}
|
||||
def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
|
||||
if logger_type == 'custom':
|
||||
raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
|
||||
try:
|
||||
logger_class = logger_type_map[logger_type]
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
|
||||
return logger_class(data_path, **kwargs)
|
||||
|
||||
class BaseLoader:
|
||||
"""
|
||||
An abstract class representing an object that can load a model checkpoint.
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
"""
|
||||
def __init__(self, data_path: str, **kwargs):
|
||||
self.data_path = Path(data_path)
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def recall() -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
class UrlLoader(BaseLoader):
|
||||
"""
|
||||
A loader that downloads the file from a url and loads it
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
url (str): The url to download the file from.
|
||||
"""
|
||||
def __init__(self, data_path: str, url: str, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.url = url
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
# Makes sure the file exists to be downloaded
|
||||
pass # TODO: Actually implement that
|
||||
|
||||
def recall(self) -> dict:
|
||||
# Download the file
|
||||
save_path = self.data_path / 'loaded_checkpoint.pth'
|
||||
urllib.request.urlretrieve(self.url, str(save_path))
|
||||
# Load the file
|
||||
return torch.load(str(save_path), map_location='cpu')
|
||||
|
||||
|
||||
class LocalLoader(BaseLoader):
|
||||
"""
|
||||
A loader that loads a file from a local path
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
file_path (str): The path to the file to load.
|
||||
"""
|
||||
def __init__(self, data_path: str, file_path: str, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.file_path = Path(file_path)
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
# Makes sure the file exists to be loaded
|
||||
if not self.file_path.exists():
|
||||
raise FileNotFoundError(f'Model not found at {self.file_path}')
|
||||
|
||||
def recall(self) -> dict:
|
||||
# Load the file
|
||||
return torch.load(str(self.file_path), map_location='cpu')
|
||||
|
||||
class WandbLoader(BaseLoader):
|
||||
"""
|
||||
A loader that loads a model from an existing wandb run
|
||||
"""
|
||||
def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.run_path = wandb_run_path
|
||||
self.file_path = wandb_file_path
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
|
||||
# Make sure the file can be downloaded
|
||||
if self.wandb.run is not None and self.run_path is None:
|
||||
self.run_path = self.wandb.run.path
|
||||
assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'
|
||||
assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'
|
||||
assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
|
||||
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
pass # TODO: Actually implement that
|
||||
|
||||
def recall(self) -> dict:
|
||||
file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
|
||||
return torch.load(file_reference.name, map_location='cpu')
|
||||
|
||||
loader_type_map = {
|
||||
'url': UrlLoader,
|
||||
'local': LocalLoader,
|
||||
'wandb': WandbLoader,
|
||||
}
|
||||
def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
|
||||
if loader_type == 'custom':
|
||||
raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
|
||||
try:
|
||||
loader_class = loader_type_map[loader_type]
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
|
||||
return loader_class(data_path, **kwargs)
|
||||
|
||||
class BaseSaver:
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
|
||||
save_best_to: Optional[Union[str, bool]] = 'best.pth',
|
||||
save_meta_to: str = './',
|
||||
save_type: str = 'checkpoint',
|
||||
**kwargs
|
||||
):
|
||||
self.data_path = Path(data_path)
|
||||
self.save_latest_to = save_latest_to
|
||||
self.saving_latest = save_latest_to is not None and save_latest_to is not False
|
||||
self.save_best_to = save_best_to
|
||||
self.saving_best = save_best_to is not None and save_best_to is not False
|
||||
self.save_meta_to = save_meta_to
|
||||
self.save_type = save_type
|
||||
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
|
||||
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
|
||||
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
|
||||
"""
|
||||
Save a general file under save_meta_to
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
class LocalSaver(BaseSaver):
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(data_path, **kwargs)
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
# Makes sure the directory exists to be saved to
|
||||
print(f"Saving {self.save_type} locally")
|
||||
if not self.data_path.exists():
|
||||
self.data_path.mkdir(parents=True)
|
||||
|
||||
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
|
||||
# Copy the file to save_path
|
||||
save_path_file_name = Path(save_path).name
|
||||
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
|
||||
shutil.copy(local_path, save_path)
|
||||
|
||||
class WandbSaver(BaseSaver):
|
||||
def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.run_path = wandb_run_path
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
# Makes sure that the user can upload tot his run
|
||||
if self.run_path is not None:
|
||||
entity, project, run_id = self.run_path.split("/")
|
||||
self.run = self.wandb.init(entity=entity, project=project, id=run_id)
|
||||
else:
|
||||
assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
|
||||
self.run = self.wandb.run
|
||||
# TODO: Now actually check if upload is possible
|
||||
print(f"Saving to wandb run {self.run.path}-{self.run.name}")
|
||||
|
||||
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
|
||||
# In order to log something in the correct place in wandb, we need to have the same file structure here
|
||||
save_path_file_name = Path(save_path).name
|
||||
print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
|
||||
save_path = Path(self.data_path) / save_path
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(local_path, save_path)
|
||||
self.run.save(str(save_path), base_path = str(self.data_path), policy='now')
|
||||
|
||||
class HuggingfaceSaver(BaseSaver):
|
||||
def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.huggingface_repo = huggingface_repo
|
||||
self.token_path = token_path
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs):
|
||||
# Makes sure this user can upload to the repo
|
||||
self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
|
||||
try:
|
||||
identity = self.hub.whoami() # Errors if not logged in
|
||||
# Then we are logged in
|
||||
except:
|
||||
# We are not logged in. Use the token_path to set the token.
|
||||
if not os.path.exists(self.token_path):
|
||||
raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
|
||||
with open(self.token_path, "r") as f:
|
||||
token = f.read().strip()
|
||||
self.hub.HfApi.set_access_token(token)
|
||||
identity = self.hub.whoami()
|
||||
print(f"Saving to huggingface repo {self.huggingface_repo}")
|
||||
|
||||
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
|
||||
# Saving to huggingface is easy, we just need to upload the file with the correct name
|
||||
save_path_file_name = Path(save_path).name
|
||||
print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
|
||||
self.hub.upload_file(
|
||||
path_or_fileobj=str(local_path),
|
||||
path_in_repo=str(save_path),
|
||||
repo_id=self.huggingface_repo
|
||||
)
|
||||
|
||||
saver_type_map = {
|
||||
'local': LocalSaver,
|
||||
'wandb': WandbSaver,
|
||||
'huggingface': HuggingfaceSaver
|
||||
}
|
||||
def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
|
||||
if saver_type == 'custom':
|
||||
raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
|
||||
try:
|
||||
saver_class = saver_type_map[saver_type]
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
|
||||
return saver_class(data_path, **kwargs)
|
||||
|
||||
|
||||
class Tracker:
|
||||
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
|
||||
self.data_path = Path(data_path)
|
||||
if not dummy_mode:
|
||||
if overwrite_data_path:
|
||||
if self.data_path.exists():
|
||||
shutil.rmtree(self.data_path)
|
||||
self.data_path.mkdir(parents=True)
|
||||
else:
|
||||
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
|
||||
if not self.data_path.exists():
|
||||
self.data_path.mkdir(parents=True)
|
||||
self.logger: BaseLogger = None
|
||||
self.loader: Optional[BaseLoader] = None
|
||||
self.savers: List[BaseSaver]= []
|
||||
self.dummy_mode = dummy_mode
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict):
|
||||
assert self.logger is not None, '`logger` must be set before `init` is called'
|
||||
if self.dummy_mode:
|
||||
# The only thing we need is a loader
|
||||
if self.loader is not None:
|
||||
self.loader.init(self.logger)
|
||||
return
|
||||
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
|
||||
self.logger.init(full_config, extra_config)
|
||||
if self.loader is not None:
|
||||
self.loader.init(self.logger)
|
||||
for saver in self.savers:
|
||||
saver.init(self.logger)
|
||||
|
||||
def add_logger(self, logger: BaseLogger):
|
||||
self.logger = logger
|
||||
|
||||
def add_loader(self, loader: BaseLoader):
|
||||
self.loader = loader
|
||||
|
||||
def add_saver(self, saver: BaseSaver):
|
||||
self.savers.append(saver)
|
||||
|
||||
def log(self, *args, **kwargs):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
self.logger.log(*args, **kwargs)
|
||||
|
||||
def log_images(self, *args, **kwargs):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
self.logger.log_images(*args, **kwargs)
|
||||
|
||||
def log_file(self, *args, **kwargs):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
self.logger.log_file(*args, **kwargs)
|
||||
|
||||
def save_config(self, current_config_path: str, config_name = 'config.json'):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
# Save the config under config_name in the root folder of data_path
|
||||
shutil.copy(current_config_path, self.data_path / config_name)
|
||||
for saver in self.savers:
|
||||
remote_path = Path(saver.save_meta_to) / config_name
|
||||
saver.save_file(current_config_path, str(remote_path))
|
||||
|
||||
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
|
||||
"""
|
||||
Gets the state dict to be saved and writes it to file_path.
|
||||
If save_type is 'checkpoint', we save the entire trainer state dict.
|
||||
If save_type is 'model', we save only the model state dict.
|
||||
"""
|
||||
assert save_type in ['checkpoint', 'model']
|
||||
if save_type == 'checkpoint':
|
||||
trainer.save(file_path, overwrite=True, **kwargs)
|
||||
elif save_type == 'model':
|
||||
if isinstance(trainer, DiffusionPriorTrainer):
|
||||
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
|
||||
state_dict = trainer.unwrap_model(prior).state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
elif isinstance(trainer, DecoderTrainer):
|
||||
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
|
||||
if trainer.use_ema:
|
||||
trainable_unets = decoder.unets
|
||||
decoder.unets = trainer.unets # Swap EMA unets in
|
||||
state_dict = decoder.state_dict()
|
||||
decoder.unets = trainable_unets # Swap back
|
||||
else:
|
||||
state_dict = decoder.state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
else:
|
||||
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
|
||||
return Path(file_path)
|
||||
|
||||
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
if not is_best and not is_latest:
|
||||
# Nothing to do
|
||||
return
|
||||
# Save the checkpoint and model to data_path
|
||||
checkpoint_path = self.data_path / 'checkpoint.pth'
|
||||
self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
|
||||
model_path = self.data_path / 'model.pth'
|
||||
self._save_state_dict(trainer, 'model', model_path, **kwargs)
|
||||
print("Saved cached models")
|
||||
# Call the save methods on the savers
|
||||
for saver in self.savers:
|
||||
local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
|
||||
if saver.saving_latest and is_latest:
|
||||
latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
|
||||
try:
|
||||
saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
|
||||
except Exception as e:
|
||||
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
|
||||
print(f'Error saving checkpoint: {e}')
|
||||
if saver.saving_best and is_best:
|
||||
best_checkpoint_path = saver.save_best_to.format(**kwargs)
|
||||
try:
|
||||
saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
|
||||
except Exception as e:
|
||||
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
|
||||
print(f'Error saving checkpoint: {e}')
|
||||
|
||||
def recall(self):
|
||||
if self.loader is not None:
|
||||
return self.loader.recall()
|
||||
else:
|
||||
raise ValueError('No loader specified')
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from dalle2_pytorch.dalle2_pytorch import (
|
||||
DiffusionPriorNetwork,
|
||||
XClipAdapter
|
||||
)
|
||||
from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver
|
||||
|
||||
# helper functions
|
||||
|
||||
@@ -44,13 +45,66 @@ class TrainSplitConfig(BaseModel):
|
||||
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
|
||||
return fields
|
||||
|
||||
class TrackerLogConfig(BaseModel):
|
||||
log_type: str = 'console'
|
||||
verbose: bool = False
|
||||
|
||||
class Config:
|
||||
# Each individual log type has it's own arguments that will be passed through the config
|
||||
extra = "allow"
|
||||
|
||||
def create(self, data_path: str):
|
||||
kwargs = self.dict()
|
||||
return create_logger(self.log_type, data_path, **kwargs)
|
||||
|
||||
class TrackerLoadConfig(BaseModel):
|
||||
load_from: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def create(self, data_path: str):
|
||||
kwargs = self.dict()
|
||||
if self.load_from is None:
|
||||
return None
|
||||
return create_loader(self.load_from, data_path, **kwargs)
|
||||
|
||||
class TrackerSaveConfig(BaseModel):
|
||||
save_to: str = 'local'
|
||||
save_all: bool = False
|
||||
save_latest: bool = True
|
||||
save_best: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def create(self, data_path: str):
|
||||
kwargs = self.dict()
|
||||
return create_saver(self.save_to, data_path, **kwargs)
|
||||
|
||||
class TrackerConfig(BaseModel):
|
||||
tracker_type: str = 'console' # Decoder currently supports console and wandb
|
||||
data_path: str = './models' # The path where files will be saved locally
|
||||
init_config: Dict[str, Any] = None
|
||||
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
|
||||
wandb_project: str = ''
|
||||
verbose: bool = False # Whether to print console logging for non-console trackers
|
||||
data_path: str = '.tracker_data'
|
||||
overwrite_data_path: bool = False
|
||||
log: TrackerLogConfig
|
||||
load: Optional[TrackerLoadConfig]
|
||||
save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
|
||||
|
||||
def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
|
||||
tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
|
||||
# Add the logger
|
||||
tracker.add_logger(self.log.create(self.data_path))
|
||||
# Add the loader
|
||||
if self.load is not None:
|
||||
tracker.add_loader(self.load.create(self.data_path))
|
||||
# Add the saver or savers
|
||||
if isinstance(self.save, list):
|
||||
for save_config in self.save:
|
||||
tracker.add_saver(save_config.create(self.data_path))
|
||||
else:
|
||||
tracker.add_saver(self.save.create(self.data_path))
|
||||
# Initialize all the components and verify that all data is valid
|
||||
tracker.init(full_config, extra_config)
|
||||
return tracker
|
||||
|
||||
# diffusion prior pydantic classes
|
||||
|
||||
@@ -162,6 +216,7 @@ class UnetConfig(BaseModel):
|
||||
cond_on_text_encodings: bool = None
|
||||
cond_dim: int = None
|
||||
channels: int = 3
|
||||
self_attn: ListOrTuple(int)
|
||||
attn_dim_head: int = 32
|
||||
attn_heads: int = 16
|
||||
|
||||
@@ -238,6 +293,8 @@ class DecoderTrainConfig(BaseModel):
|
||||
epochs: int = 20
|
||||
lr: SingularOrIterable(float) = 1e-4
|
||||
wd: SingularOrIterable(float) = 0.01
|
||||
warmup_steps: Optional[SingularOrIterable(int)] = None
|
||||
find_unused_parameters: bool = True
|
||||
max_grad_norm: SingularOrIterable(float) = 0.5
|
||||
save_every_n_samples: int = 100000
|
||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||
@@ -247,9 +304,6 @@ class DecoderTrainConfig(BaseModel):
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.999
|
||||
amp: bool = False
|
||||
save_all: bool = False # Whether to preserve all checkpoints
|
||||
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||
save_best: bool = True # Whether to save the best checkpoint
|
||||
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
|
||||
|
||||
class DecoderEvaluateConfig(BaseModel):
|
||||
@@ -271,7 +325,6 @@ class TrainDecoderConfig(BaseModel):
|
||||
train: DecoderTrainConfig
|
||||
evaluate: DecoderEvaluateConfig
|
||||
tracker: TrackerConfig
|
||||
load: DecoderLoadConfig
|
||||
seed: int = 0
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -3,10 +3,13 @@ import copy
|
||||
from pathlib import Path
|
||||
from math import ceil
|
||||
from functools import partial, wraps
|
||||
from contextlib import nullcontext
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
@@ -14,6 +17,8 @@ from dalle2_pytorch.optimizer import get_optimizer
|
||||
from dalle2_pytorch.version import __version__
|
||||
from packaging import version
|
||||
|
||||
import pytorch_warmup as warmup
|
||||
|
||||
from ema_pytorch import EMA
|
||||
|
||||
from accelerate import Accelerator
|
||||
@@ -162,19 +167,32 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
group_wd_params = True,
|
||||
device = None,
|
||||
accelerator = None,
|
||||
verbose = True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
|
||||
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
|
||||
# verbosity
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
# assign some helpful member vars
|
||||
|
||||
self.accelerator = accelerator
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
self.text_conditioned = diffusion_prior.condition_on_text_encodings
|
||||
|
||||
# setting the device
|
||||
|
||||
if not exists(accelerator) and not exists(device):
|
||||
diffusion_prior_device = next(diffusion_prior.parameters()).device
|
||||
self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
|
||||
self.device = diffusion_prior_device
|
||||
else:
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
|
||||
# save model
|
||||
|
||||
self.diffusion_prior = diffusion_prior
|
||||
@@ -210,11 +228,14 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
# track steps internally
|
||||
|
||||
self.register_buffer('step', torch.tensor([0]))
|
||||
self.register_buffer('step', torch.tensor([0], device = device))
|
||||
|
||||
# accelerator wrappers
|
||||
|
||||
def print(self, msg):
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.print(msg)
|
||||
else:
|
||||
@@ -428,6 +449,7 @@ class DecoderTrainer(nn.Module):
|
||||
lr = 1e-4,
|
||||
wd = 1e-2,
|
||||
eps = 1e-8,
|
||||
warmup_steps = None,
|
||||
max_grad_norm = 0.5,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
@@ -449,13 +471,15 @@ class DecoderTrainer(nn.Module):
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
|
||||
|
||||
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
|
||||
optimizers = []
|
||||
schedulers = []
|
||||
warmup_schedulers = []
|
||||
|
||||
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
|
||||
optimizer = get_optimizer(
|
||||
unet.parameters(),
|
||||
lr = unet_lr,
|
||||
@@ -467,6 +491,13 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
optimizers.append(optimizer)
|
||||
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
|
||||
warmup_schedulers.append(warmup_scheduler)
|
||||
|
||||
schedulers.append(scheduler)
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
||||
|
||||
@@ -474,15 +505,27 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
|
||||
|
||||
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
schedulers = list(self.accelerator.prepare(*schedulers))
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
# store optimizers
|
||||
|
||||
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
|
||||
setattr(self, f'optim{opt_ind}', optimizer)
|
||||
|
||||
# store schedulers
|
||||
|
||||
for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
|
||||
setattr(self, f'sched{sched_ind}', scheduler)
|
||||
|
||||
# store warmup schedulers
|
||||
|
||||
self.warmup_schedulers = warmup_schedulers
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
@@ -491,7 +534,7 @@ class DecoderTrainer(nn.Module):
|
||||
save_obj = dict(
|
||||
model = self.accelerator.unwrap_model(self.decoder).state_dict(),
|
||||
version = __version__,
|
||||
step = self.step.item(),
|
||||
steps = self.steps.cpu(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -505,30 +548,38 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.accelerator.save(save_obj, str(path))
|
||||
|
||||
def load_state_dict(self, loaded_obj, only_model = False, strict = True):
|
||||
if version.parse(__version__) != version.parse(loaded_obj['version']):
|
||||
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
|
||||
|
||||
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.steps.copy_(loaded_obj['steps'])
|
||||
|
||||
if only_model:
|
||||
return loaded_obj
|
||||
|
||||
for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):
|
||||
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
warmup_scheduler = self.warmup_schedulers[ind]
|
||||
|
||||
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if exists(warmup_scheduler):
|
||||
warmup_scheduler.last_step = last_step
|
||||
|
||||
if self.use_ema:
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
|
||||
def load(self, path, only_model = False, strict = True):
|
||||
path = Path(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded_obj = torch.load(str(path), map_location = 'cpu')
|
||||
|
||||
if version.parse(__version__) != version.parse(loaded_obj['version']):
|
||||
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
|
||||
|
||||
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
|
||||
if only_model:
|
||||
return loaded_obj
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
|
||||
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if self.use_ema:
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)
|
||||
|
||||
return loaded_obj
|
||||
|
||||
@@ -536,6 +587,12 @@ class DecoderTrainer(nn.Module):
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
|
||||
def increment_step(self, unet_number):
|
||||
assert 1 <= unet_number <= self.num_unets
|
||||
|
||||
unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
|
||||
self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
|
||||
|
||||
def update(self, unet_number = None):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
@@ -544,17 +601,25 @@ class DecoderTrainer(nn.Module):
|
||||
index = unet_number - 1
|
||||
|
||||
optimizer = getattr(self, f'optim{index}')
|
||||
scheduler = getattr(self, f'sched{index}')
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
warmup_scheduler = self.warmup_schedulers[index]
|
||||
scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
|
||||
|
||||
with scheduler_context():
|
||||
scheduler.step()
|
||||
|
||||
if self.use_ema:
|
||||
ema_unet = self.ema_unets[index]
|
||||
ema_unet.update()
|
||||
|
||||
self.step += 1
|
||||
self.increment_step(unet_number)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@@ -604,7 +669,6 @@ class DecoderTrainer(nn.Module):
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
# with autocast(enabled = self.amp):
|
||||
with self.accelerator.autocast():
|
||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||
loss = loss * chunk_size_frac
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import time
|
||||
import importlib
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
# time helpers
|
||||
|
||||
class Timer:
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.14.1'
|
||||
__version__ = '0.16.11'
|
||||
|
||||
1
setup.py
1
setup.py
@@ -37,6 +37,7 @@ setup(
|
||||
'packaging',
|
||||
'pillow',
|
||||
'pydantic',
|
||||
'pytorch-warmup',
|
||||
'resize-right>=0.0.2',
|
||||
'rotary-embedding-torch',
|
||||
'torch>=1.10',
|
||||
|
||||
106
train_decoder.py
106
train_decoder.py
@@ -1,11 +1,12 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from dalle2_pytorch.trainer import DecoderTrainer
|
||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
|
||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||
from dalle2_pytorch.trackers import Tracker
|
||||
from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
|
||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
|
||||
from clip import tokenize
|
||||
|
||||
import torchvision
|
||||
@@ -239,42 +240,33 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
|
||||
metrics[metric_name] = metrics_tensor[i].item()
|
||||
return metrics
|
||||
|
||||
def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths):
|
||||
def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
|
||||
"""
|
||||
Logs the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
if isinstance(relative_paths, str):
|
||||
relative_paths = [relative_paths]
|
||||
for relative_path in relative_paths:
|
||||
local_path = str(tracker.data_path / relative_path)
|
||||
trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses)
|
||||
tracker.save_file(local_path)
|
||||
tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)
|
||||
|
||||
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
|
||||
def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
|
||||
"""
|
||||
Loads the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}"))
|
||||
local_filepath = tracker.recall_file(recall_source, **load_config)
|
||||
state_dict = trainer.load(local_filepath)
|
||||
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0)
|
||||
trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
|
||||
state_dict = tracker.recall()
|
||||
trainer.load_state_dict(state_dict, only_model=False, strict=True)
|
||||
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)
|
||||
|
||||
def train(
|
||||
dataloaders,
|
||||
decoder,
|
||||
accelerator,
|
||||
tracker,
|
||||
decoder: Decoder,
|
||||
accelerator: Accelerator,
|
||||
tracker: Tracker,
|
||||
inference_device,
|
||||
load_config=None,
|
||||
evaluate_config=None,
|
||||
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
|
||||
validation_samples = None,
|
||||
epochs = 20,
|
||||
n_sample_images = 5,
|
||||
save_every_n_samples = 100000,
|
||||
save_all=False,
|
||||
save_latest=True,
|
||||
save_best=True,
|
||||
unet_training_mask=None,
|
||||
condition_on_text_encodings=False,
|
||||
**kwargs
|
||||
@@ -299,13 +291,13 @@ def train(
|
||||
val_sample = 0
|
||||
step = lambda: int(trainer.step.item())
|
||||
|
||||
if exists(load_config) and exists(load_config.source):
|
||||
start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
|
||||
if tracker.loader is not None:
|
||||
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
|
||||
if next_task == 'train':
|
||||
sample = recalled_sample
|
||||
if next_task == 'val':
|
||||
val_sample = recalled_sample
|
||||
accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
|
||||
accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
|
||||
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
|
||||
trainer.to(device=inference_device)
|
||||
|
||||
@@ -399,19 +391,14 @@ def train(
|
||||
}
|
||||
|
||||
if is_master:
|
||||
tracker.log(log_data, step=step(), verbose=True)
|
||||
tracker.log(log_data, step=step())
|
||||
|
||||
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
|
||||
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
|
||||
print("Saving snapshot")
|
||||
last_snapshot = sample
|
||||
# We need to know where the model should be saved
|
||||
save_paths = []
|
||||
if save_latest:
|
||||
save_paths.append("latest.pth")
|
||||
if save_all:
|
||||
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth")
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
|
||||
if exists(n_sample_images) and n_sample_images > 0:
|
||||
trainer.eval()
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
|
||||
@@ -486,7 +473,7 @@ def train(
|
||||
if is_master:
|
||||
unet_average_val_loss = all_average_val_losses.mean(dim=0)
|
||||
val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
|
||||
tracker.log(val_loss_map, step=step(), verbose=True)
|
||||
tracker.log(val_loss_map, step=step())
|
||||
next_task = 'eval'
|
||||
|
||||
if next_task == 'eval':
|
||||
@@ -494,7 +481,7 @@ def train(
|
||||
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
|
||||
if is_master:
|
||||
tracker.log(evaluation, step=step(), verbose=True)
|
||||
tracker.log(evaluation, step=step())
|
||||
next_task = 'sample'
|
||||
val_sample = 0
|
||||
|
||||
@@ -509,22 +496,16 @@ def train(
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
|
||||
# Get the same paths
|
||||
save_paths = []
|
||||
if save_latest:
|
||||
save_paths.append("latest.pth")
|
||||
is_best = False
|
||||
if all_average_val_losses is not None:
|
||||
average_loss = all_average_val_losses.mean(dim=0).item()
|
||||
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
|
||||
save_paths.append("best.pth")
|
||||
if len(validation_losses) == 0 or average_loss < min(validation_losses):
|
||||
is_best = True
|
||||
validation_losses.append(average_loss)
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
|
||||
next_task = 'train'
|
||||
|
||||
def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None):
|
||||
"""
|
||||
Creates a tracker of the specified type and initializes special features based on the full config
|
||||
"""
|
||||
def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
|
||||
tracker_config = config.tracker
|
||||
accelerator_config = {
|
||||
"Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
|
||||
@@ -532,40 +513,16 @@ def create_tracker(accelerator, config, config_path, tracker_type=None, data_pat
|
||||
"NumProcesses": accelerator.num_processes,
|
||||
"MixedPrecision": accelerator.mixed_precision
|
||||
}
|
||||
init_config = { "config": {**config.dict(), **accelerator_config} }
|
||||
data_path = data_path or tracker_config.data_path
|
||||
tracker_type = tracker_type or tracker_config.tracker_type
|
||||
|
||||
if tracker_type == "dummy":
|
||||
tracker = DummyTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
elif tracker_type == "console":
|
||||
tracker = ConsoleTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
elif tracker_type == "wandb":
|
||||
# We need to initialize the resume state here
|
||||
load_config = config.load
|
||||
if load_config.source == "wandb" and load_config.resume:
|
||||
# Then we are resuming the run load_config["run_path"]
|
||||
run_id = load_config.run_path.split("/")[-1]
|
||||
init_config["id"] = run_id
|
||||
init_config["resume"] = "must"
|
||||
|
||||
init_config["entity"] = tracker_config.wandb_entity
|
||||
init_config["project"] = tracker_config.wandb_project
|
||||
tracker = WandbTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute()))
|
||||
else:
|
||||
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
|
||||
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
|
||||
tracker.save_config(config_path, config_name='decoder_config.json')
|
||||
return tracker
|
||||
|
||||
def initialize_training(config, config_path):
|
||||
def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
# Make sure if we are not loading, distributed models are initialized to the same values
|
||||
torch.manual_seed(config.seed)
|
||||
|
||||
# Set up accelerator for configurable distributed training
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
||||
|
||||
# Set up data
|
||||
@@ -592,7 +549,7 @@ def initialize_training(config, config_path):
|
||||
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||
|
||||
# Create and initialize the tracker if we are the master
|
||||
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
|
||||
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
|
||||
|
||||
has_img_embeddings = config.data.img_embeddings_url is not None
|
||||
has_text_embeddings = config.data.text_embeddings_url is not None
|
||||
@@ -622,7 +579,6 @@ def initialize_training(config, config_path):
|
||||
train(dataloaders, decoder, accelerator,
|
||||
tracker=tracker,
|
||||
inference_device=accelerator.device,
|
||||
load_config=config.load,
|
||||
evaluate_config=config.evaluate,
|
||||
condition_on_text_encodings=conditioning_on_text,
|
||||
**config.train.dict(),
|
||||
|
||||
Reference in New Issue
Block a user