mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 20:54:22 +01:00
Compare commits
9 Commits
| Author | SHA1 | Date |
| --- | --- | --- |
|  | 282c35930f |  |
|  | 27b0f7ca0d |  |
|  | 7b0edf9e42 |  |
|  | a922a539de |  |
|  | 8f2466f1cd |  |
|  | 908ab83799 |  |
|  | 46a2558d53 |  |
|  | 86109646e3 |  |
|  | 6a11b9678b |  |
README.md — 19 changed lines
````diff
@@ -1112,15 +1112,6 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
-```bibtex
-@inproceedings{Tu2022MaxViTMV,
-    title = {MaxViT: Multi-Axis Vision Transformer},
-    author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
-    year = {2022},
-    url = {https://arxiv.org/abs/2204.01697}
-}
-```
-
 ```bibtex
 @article{Yu2021VectorquantizedIM,
     title = {Vector-quantized Image Modeling with Improved VQGAN},
@@ -1189,4 +1180,14 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
+```bibtex
+@article{Saharia2021PaletteID,
+    title = {Palette: Image-to-Image Diffusion Models},
+    author = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},
+    journal = {ArXiv},
+    year = {2021},
+    volume = {abs/2111.05826}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
````
```diff
@@ -91,21 +91,83 @@ Each metric can be enabled by setting its configuration. The configuration keys
 
 **<ins>Tracker</ins>:**
 
-Selects which tracker to use and configures it.
+Selects how the experiment will be tracked.
+
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
-| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
-| `data_path` | No | `./models` | Where the tracker will store local data. |
-| `verbose` | No | `False` | Enables console logging for non-console trackers. |
+| `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. |
+| `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. |
+| `log` | Yes | N/A | Logging configuration. |
+| `load` | No | `None` | Checkpoint loading configuration. |
+| `save` | Yes | N/A | Checkpoint/Model saving configuration. |
+
+Tracking is split up into three sections:
+* Log: Where to save run metadata and image output. Options are `console` or `wandb`.
+* Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`.
+* Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`.
 
-Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
+**Logging:**
 
-**<ins>Load</ins>:**
+If using `console` there is no further configuration than setting `log_type` to `console`.
 
-Selects where to load a pretrained model from.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
-| `source` | No | `None` | Supports `file` or `wandb`. |
-| `resume` | No | `False` | If the tracker support resuming the run, resume it. |
+| `log_type` | Yes | N/A | Must be `console`. |
 
-Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).
+If using `wandb`
+
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `log_type` | Yes | N/A | Must be `wandb`. |
+| `wandb_entity` | Yes | N/A | The wandb entity to log to. |
+| `wandb_project` | Yes | N/A | The wandb project to save the run to. |
+| `wandb_run_name` | No | `None` | The wandb run name. |
+| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
+| `wandb_resume` | No | `False` | Whether to resume an old run. |
+
+**Loading:**
+
+If using `local`
+
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | Must be `local`. |
+| `file_path` | Yes | N/A | The path to the checkpoint file. |
+
+If using `url`
+
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | Must be `url`. |
+| `url` | Yes | N/A | The url of the checkpoint file. |
+
+If using `wandb`
+
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | Must be `wandb`. |
+| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. |
+| `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. |
+
+**Saving:**
+
+Unlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a run.
+
+All save locations have these configuration options
+
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
+| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
+| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
+| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (saves with ema if ema is enabled). |
+
+If using `local`
+
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `local`. |
+
+If using `huggingface`
+
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `huggingface`. |
+| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
+| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
+| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
+
+If using `wandb`
+
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `wandb`. |
+| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. |
```
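Because `save` may be a single object or an array, a config consumer has to dispatch per entry. This is a minimal sketch of that dispatch in plain Python (the `dispatch_saves` helper is hypothetical, not part of the library):

```python
def dispatch_saves(save_configs):
    # `save` may be a single dict or a list of dicts;
    # each entry selects a backend via its `save_to` key
    if isinstance(save_configs, dict):
        save_configs = [save_configs]
    targets = []
    for conf in save_configs:
        save_to = conf['save_to']
        assert save_to in ('local', 'huggingface', 'wandb'), f'unknown save target: {save_to}'
        targets.append(save_to)
    return targets

print(dispatch_saves([{'save_to': 'wandb'},
                      {'save_to': 'huggingface', 'huggingface_repo': 'user/repo'}]))
# → ['wandb', 'huggingface']
```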
```diff
@@ -80,20 +80,32 @@
         }
     },
     "tracker": {
-        "tracker_type": "console",
-        "data_path": "./models",
-        "wandb_entity": "",
-        "wandb_project": "",
-        "verbose": false
-    },
-    "load": {
-        "source": null,
-        "run_path": "",
-        "file_path": "",
-        "resume": false
+        "overwrite_data_path": true,
+        "log": {
+            "log_type": "wandb",
+            "wandb_entity": "your_wandb",
+            "wandb_project": "your_project",
+            "verbose": true
+        },
+        "load": {
+            "load_from": null
+        },
+        "save": [{
+            "save_to": "wandb"
+        }, {
+            "save_to": "huggingface",
+            "huggingface_repo": "Veldrovive/test_model",
+            "save_all": true,
+            "save_latest": true,
+            "save_best": true,
+            "save_type": "model"
+        }]
     }
 }
```
```diff
@@ -45,6 +45,11 @@ def exists(val):
 def identity(t, *args, **kwargs):
     return t
 
+def first(arr, d = None):
+    if len(arr) == 0:
+        return d
+    return arr[0]
+
 def maybe(fn):
     @wraps(fn)
     def inner(x):
```
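The `first` helper added above is small enough to exercise in isolation; this standalone copy (nothing beyond the Python standard library) mirrors the version in the diff:

```python
def first(arr, d = None):
    # return the first element of a sequence, or a default when empty
    if len(arr) == 0:
        return d
    return arr[0]

print(first([5, 2, 3]))       # → 5
print(first([], 'fallback'))  # → fallback
```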
```diff
@@ -351,7 +356,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
     steps = timesteps + 1
     x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
     alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
-    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    alphas_cumprod = alphas_cumprod / first(alphas_cumprod)
     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
     return torch.clip(betas, 0, 0.999)
 
```
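This hunk only swaps the indexing for the new `first` helper; the schedule itself is unchanged. As a sanity check, here is a dependency-free sketch of the same cosine beta schedule using plain Python lists in place of torch tensors (values should agree with the torch version up to float precision):

```python
import math

def cosine_beta_schedule(timesteps, s = 0.008):
    # alpha-bar follows a squared-cosine curve, shifted by a small offset s
    steps = timesteps + 1
    alphas_cumprod = [math.cos(((t / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
                      for t in range(steps)]
    # normalize so the schedule starts at alpha-bar = 1 (the `first(...)` line in the diff)
    alphas_cumprod = [a / alphas_cumprod[0] for a in alphas_cumprod]
    betas = [1 - (alphas_cumprod[i + 1] / alphas_cumprod[i]) for i in range(timesteps)]
    # clip exactly as the torch version does
    return [min(max(b, 0.0), 0.999) for b in betas]

betas = cosine_beta_schedule(1000)
print(len(betas), betas[0] < betas[-1])  # → 1000 True
```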
```diff
@@ -1088,8 +1093,16 @@ class DiffusionPrior(nn.Module):
 
 # decoder
 
-def Upsample(dim):
-    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+def ConvTransposeUpsample(dim, dim_out = None):
+    dim_out = default(dim_out, dim)
+    return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
+
+def NearestUpsample(dim, dim_out = None):
+    dim_out = default(dim_out, dim)
+    return nn.Sequential(
+        nn.Upsample(scale_factor = 2, mode = 'nearest'),
+        nn.Conv2d(dim, dim_out, 3, padding = 1)
+    )
 
 def Downsample(dim, *, dim_out = None):
     dim_out = default(dim_out, dim)
```
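The refactor above offers a choice between transposed convolution and nearest-neighbor upsampling followed by a 3x3 convolution (the latter is a common remedy for checkerboard artifacts). Both variants double the spatial size; this plain-Python shape arithmetic shows why the kernel/stride/padding choices line up:

```python
def conv_transpose_out(size, kernel = 4, stride = 2, padding = 1):
    # ConvTranspose2d output size: (in - 1) * stride - 2 * padding + kernel
    return (size - 1) * stride - 2 * padding + kernel

def nearest_then_conv_out(size, kernel = 3, padding = 1):
    # nearest upsample doubles the size; a stride-1 conv with padding 1 preserves it
    upsampled = size * 2
    return upsampled + 2 * padding - kernel + 1

for s in (8, 16, 64):
    assert conv_transpose_out(s) == nearest_then_conv_out(s) == 2 * s
print("both upsamplers double spatial size")
```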
```diff
@@ -1166,7 +1179,7 @@ class ResnetBlock(nn.Module):
         self.block2 = Block(dim_out, dim_out, groups = groups)
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
-    def forward(self, x, cond = None, time_emb = None):
+    def forward(self, x, time_emb = None, cond = None):
 
         scale_shift = None
         if exists(self.time_mlp) and exists(time_emb):
```
```diff
@@ -1247,20 +1260,6 @@ class CrossAttention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
-class GridAttention(nn.Module):
-    def __init__(self, *args, window_size = 8, **kwargs):
-        super().__init__()
-        self.window_size = window_size
-        self.attn = Attention(*args, **kwargs)
-
-    def forward(self, x):
-        h, w = x.shape[-2:]
-        wsz = self.window_size
-        x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
-        out = self.attn(x)
-        out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
-        return out
-
 class LinearAttention(nn.Module):
     def __init__(
         self,
```
```diff
@@ -1359,6 +1358,9 @@ class Unet(nn.Module):
         cross_embed_downsample = False,
         cross_embed_downsample_kernel_sizes = (2, 4),
         memory_efficient = False,
+        scale_skip_connection = False,
+        nearest_upsample = False,
+        final_conv_kernel_size = 1,
         **kwargs
     ):
         super().__init__()
```
```diff
@@ -1440,6 +1442,10 @@ class Unet(nn.Module):
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
 
+        # whether to scale skip connection, adopted in Imagen
+
+        self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
+
         # attention related params
 
         attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
```
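One common motivation for the `2 ** -0.5` factor (as used in Imagen) is variance preservation: when two roughly unit-variance feature maps are combined, scaling keeps the result near unit variance instead of doubling it. A quick numeric illustration with the standard library only (the variance check is statistical, not exact):

```python
import random

scale = 2 ** -0.5
print(round(scale, 4))  # → 0.7071

# adding two independent unit-variance signals doubles the variance;
# scaling by 2^-0.5 brings the sum back to ~1
random.seed(0)
a = [random.gauss(0, 1) for _ in range(100_000)]
b = [random.gauss(0, 1) for _ in range(100_000)]

def var(xs):
    m = sum(xs) / len(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

summed = [(x + y) * scale for x, y in zip(a, b)]
print(round(var(summed), 1))  # ~1.0
```

Note that in the diff the scale multiplies the popped skip connection before concatenation; the variance argument then applies to whatever summation the following resnet block performs.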
```diff
@@ -1447,6 +1453,8 @@ class Unet(nn.Module):
         # resnet block klass
 
         resnet_groups = cast_tuple(resnet_groups, len(in_out))
+        top_level_resnet_group = first(resnet_groups)
+
         num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
 
         assert len(resnet_groups) == len(in_out)
```
```diff
@@ -1457,23 +1465,36 @@ class Unet(nn.Module):
         if cross_embed_downsample:
             downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
 
+        # upsample klass
+
+        upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
+
+        # give memory efficient unet an initial resnet block
+
+        self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
+
         # layers
 
         self.downs = nn.ModuleList([])
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
+        skip_connect_dims = [] # keeping track of skip connection dimensions
+
         for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
             is_first = ind == 0
             is_last = ind >= (num_resolutions - 1)
             layer_cond_dim = cond_dim if not is_first else None
 
+            dim_layer = dim_out if memory_efficient else dim_in
+            skip_connect_dims.append(dim_layer)
+
             self.downs.append(nn.ModuleList([
                 downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
-                ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
-                Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
-                downsample_klass(dim_out) if not is_last and not memory_efficient else None
+                ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
+                Residual(LinearAttention(dim_layer, **attn_kwargs)) if sparse_attn else nn.Identity(),
+                nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+                downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
             ]))
 
         mid_dim = dims[-1]
```
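The new `skip_connect_dims` bookkeeping records, at each resolution, the channel width of the features pushed onto the skip stack: `dim_out` when the memory-efficient variant downsamples first, otherwise `dim_in`. A stand-alone sketch of that bookkeeping, using made-up channel dimensions:

```python
def compute_skip_connect_dims(in_out, memory_efficient):
    # mirrors the down-path loop: the working width at each stage is
    # dim_out if we downsample first (memory efficient), else dim_in
    return [dim_out if memory_efficient else dim_in for dim_in, dim_out in in_out]

in_out = [(64, 128), (128, 256), (256, 512)]
print(compute_skip_connect_dims(in_out, memory_efficient = False))  # → [64, 128, 256]
print(compute_skip_connect_dims(in_out, memory_efficient = True))   # → [128, 256, 512]
```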
```diff
@@ -1486,17 +1507,17 @@ class Unet(nn.Module):
             is_last = ind >= (len(in_out) - 1)
             layer_cond_dim = cond_dim if not is_last else None
 
+            skip_connect_dim = skip_connect_dims.pop()
+
             self.ups.append(nn.ModuleList([
-                ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
-                Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
-                Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
+                ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
+                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+                upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
             ]))
 
-        self.final_conv = nn.Sequential(
-            ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
-            nn.Conv2d(dim, self.channels_out, 1)
-        )
+        self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
+        self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
 
         # if the current settings for the unet are not correct
         # for cascading DDPM, then reinit the unet with the right settings
```
```diff
@@ -1660,6 +1681,11 @@ class Unet(nn.Module):
         c = self.norm_cond(c)
         mid_c = self.norm_mid_cond(mid_c)
 
+        # initial resnet block
+
+        if exists(self.init_resnet_block):
+            x = self.init_resnet_block(x, t)
+
         # go through the layers of the unet, down and up
 
         hiddens = []
```
```diff
@@ -1668,36 +1694,41 @@ class Unet(nn.Module):
             if exists(pre_downsample):
                 x = pre_downsample(x)
 
-            x = init_block(x, c, t)
+            x = init_block(x, t, c)
             x = sparse_attn(x)
+            hiddens.append(x)
 
             for resnet_block in resnet_blocks:
-                x = resnet_block(x, c, t)
+                x = resnet_block(x, t, c)
+                hiddens.append(x)
 
-            hiddens.append(x)
-
             if exists(post_downsample):
                 x = post_downsample(x)
 
-        x = self.mid_block1(x, mid_c, t)
+        x = self.mid_block1(x, t, mid_c)
 
         if exists(self.mid_attn):
             x = self.mid_attn(x)
 
-        x = self.mid_block2(x, mid_c, t)
+        x = self.mid_block2(x, t, mid_c)
+
+        connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
 
         for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
-            x = torch.cat((x, hiddens.pop()), dim = 1)
-            x = init_block(x, c, t)
+            x = connect_skip(x)
+            x = init_block(x, t, c)
             x = sparse_attn(x)
 
             for resnet_block in resnet_blocks:
-                x = resnet_block(x, c, t)
+                x = connect_skip(x)
+                x = resnet_block(x, t, c)
 
             x = upsample(x)
 
         x = torch.cat((x, r), dim = 1)
-        return self.final_conv(x)
+
+        x = self.final_resnet_block(x, t)
+        return self.to_out(x)
 
 class LowresConditioner(nn.Module):
     def __init__(
```
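In the rewritten forward pass, each down stage now pushes one hidden after its init block plus one per resnet block, while the up path pops once before its init block and once before each resnet block, so the skip stack empties exactly. A toy model of that push/pop accounting (pure Python; the stage and block counts are made up):

```python
def skip_stack_balance(num_stages, blocks_per_stage):
    hiddens = []
    # down path: init block output + each resnet block output are saved
    for _ in range(num_stages):
        hiddens.append('init')
        for _ in range(blocks_per_stage):
            hiddens.append('resnet')
    # up path: one pop before the init block, one before each resnet block
    for _ in range(num_stages):
        hiddens.pop()
        for _ in range(blocks_per_stage):
            hiddens.pop()
    return len(hiddens)  # 0 means every skip connection was consumed

print(skip_stack_balance(num_stages = 4, blocks_per_stage = 2))  # → 0
```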
```diff
@@ -1764,7 +1795,7 @@ class Decoder(nn.Module):
         image_sizes = None,              # for cascading ddpm, image size at each stage
         random_crop_sizes = None,        # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
         lowres_downsample_first = True,  # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
-        blur_sigma = (0.1, 0.2),         # cascading ddpm - blur sigma
+        blur_sigma = 0.6,                # cascading ddpm - blur sigma
         blur_kernel_size = 3,            # cascading ddpm - blur kernel size
         clip_denoised = True,
         clip_x_start = True,
```
```diff
@@ -2194,7 +2225,8 @@ class Decoder(nn.Module):
         image_embed = None,
         text_encodings = None,
         text_mask = None,
-        unet_number = None
+        unet_number = None,
+        return_lowres_cond_image = False  # whether to return the low resolution conditioning images, for debugging upsampler purposes
     ):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
```
```diff
@@ -2244,7 +2276,12 @@ class Decoder(nn.Module):
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
 
-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
+        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
+
+        if not return_lowres_cond_image:
+            return losses
+
+        return losses, lowres_cond_img
 
 # main class
 
```
```diff
@@ -2292,6 +2329,6 @@ class DALLE2(nn.Module):
         images = list(map(self.to_pil, images.unbind(dim = 0)))
 
         if one_text:
-            return images[0]
+            return first(images)
 
         return images
```
```diff
@@ -1,12 +1,15 @@
+import urllib.request
 import os
 from pathlib import Path
-import importlib
+import shutil
 from itertools import zip_longest
+from typing import Optional, List, Union
+from pydantic import BaseModel
 
 import torch
-from torch import nn
 
 from dalle2_pytorch.utils import import_or_print_error
+from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
 
 # constants
 
```
```diff
@@ -27,126 +30,484 @@ def load_wandb_file(run_path, file_path, **kwargs):
 def load_local_file(file_path, **kwargs):
     return file_path
 
-# base class
-
-class BaseTracker(nn.Module):
-    def __init__(self, data_path = DEFAULT_DATA_PATH):
-        super().__init__()
-        self.data_path = Path(data_path)
-        self.data_path.mkdir(parents = True, exist_ok = True)
-
-    def init(self, config, **kwargs):
-        raise NotImplementedError
-
-    def log(self, log, **kwargs):
-        raise NotImplementedError
-
-    def log_images(self, images, **kwargs):
-        raise NotImplementedError
-
-    def save_state_dict(self, state_dict, relative_path, **kwargs):
-        raise NotImplementedError
-
-    def recall_state_dict(self, recall_source, *args, **kwargs):
-        """
-        Loads a state dict from any source.
-        Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
-        this should not be linked to any individual tracker.
-        """
-        # TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
-        if recall_source == 'wandb':
-            return torch.load(load_wandb_file(*args, **kwargs))
-        elif recall_source == 'local':
-            return torch.load(load_local_file(*args, **kwargs))
-        else:
-            raise ValueError('`recall_source` must be one of `wandb` or `local`')
-
-    def save_file(self, file_path, **kwargs):
-        raise NotImplementedError
-
-    def recall_file(self, recall_source, *args, **kwargs):
-        if recall_source == 'wandb':
-            return load_wandb_file(*args, **kwargs)
-        elif recall_source == 'local':
-            return load_local_file(*args, **kwargs)
-        else:
-            raise ValueError('`recall_source` must be one of `wandb` or `local`')
-
-# Tracker that no-ops all calls except for recall
-
-class DummyTracker(BaseTracker):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def init(self, config, **kwargs):
-        pass
-
-    def log(self, log, **kwargs):
-        pass
-
-    def log_images(self, images, **kwargs):
-        pass
-
-    def save_state_dict(self, state_dict, relative_path, **kwargs):
-        pass
-
-    def save_file(self, file_path, **kwargs):
-        pass
-
-# basic stdout class
-
-class ConsoleTracker(BaseTracker):
-    def init(self, **config):
-        print(config)
-
-    def log(self, log, **kwargs):
-        print(log)
-
-    def log_images(self, images, **kwargs): # noop for logging images
-        pass
-
-    def save_state_dict(self, state_dict, relative_path, **kwargs):
-        torch.save(state_dict, str(self.data_path / relative_path))
-
-    def save_file(self, file_path, **kwargs):
-        # This is a no-op for local file systems since it is already saved locally
-        pass
-
-# basic wandb class
-
-class WandbTracker(BaseTracker):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
-        os.environ["WANDB_SILENT"] = "true"
-
-    def init(self, **config):
-        self.wandb.init(**config)
-
-    def log(self, log, verbose=False, **kwargs):
-        if verbose:
-            print(log)
+class BaseLogger:
+    """
+    An abstract class representing an object that can log data.
+
+    Parameters:
+        data_path (str): A file path for storing temporary data.
+        verbose (bool): Whether or not to always print logs to the console.
+    """
+    def __init__(self, data_path: str, verbose: bool = False, **kwargs):
+        self.data_path = Path(data_path)
+        self.verbose = verbose
+
+    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
+        """
+        Initializes the logger.
+        Errors if the logger is invalid.
+        """
+        raise NotImplementedError
+
+    def log(self, log, **kwargs) -> None:
+        raise NotImplementedError
+
+    def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
+        raise NotImplementedError
+
+    def log_file(self, file_path, **kwargs) -> None:
+        raise NotImplementedError
+
+    def log_error(self, error_string, **kwargs) -> None:
+        raise NotImplementedError
+
+class ConsoleLogger(BaseLogger):
+    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
+        print("Logging to console")
+
+    def log(self, log, **kwargs) -> None:
+        print(log)
+
+    def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
+        pass
+
+    def log_file(self, file_path, **kwargs) -> None:
+        pass
+
+    def log_error(self, error_string, **kwargs) -> None:
+        print(error_string)
+
+class WandbLogger(BaseLogger):
+    """
+    Logs to a wandb run.
+
+    Parameters:
+        data_path (str): A file path for storing temporary data.
+        wandb_entity (str): The wandb entity to log to.
+        wandb_project (str): The wandb project to log to.
+        wandb_run_id (str): The wandb run id to resume.
+        wandb_run_name (str): The wandb run name to use.
+        wandb_resume (bool): Whether to resume a wandb run.
+    """
+    def __init__(self,
+        data_path: str,
+        wandb_entity: str,
+        wandb_project: str,
+        wandb_run_id: Optional[str] = None,
+        wandb_run_name: Optional[str] = None,
+        wandb_resume: bool = False,
+        **kwargs
+    ):
+        super().__init__(data_path, **kwargs)
+        self.entity = wandb_entity
+        self.project = wandb_project
+        self.run_id = wandb_run_id
+        self.run_name = wandb_run_name
+        self.resume = wandb_resume
+
+    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
+        assert self.entity is not None, "wandb_entity must be specified for wandb logger"
+        assert self.project is not None, "wandb_project must be specified for wandb logger"
+        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
+        os.environ["WANDB_SILENT"] = "true"
+        # Initializes the wandb run
+        init_object = {
+            "entity": self.entity,
+            "project": self.project,
+            "config": {**full_config.dict(), **extra_config}
+        }
+        if self.run_name is not None:
+            init_object['name'] = self.run_name
+        if self.resume:
+            assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
+            if self.run_name is not None:
+                print("You are renaming a run. I hope that is what you intended.")
+            init_object['resume'] = 'must'
+            init_object['id'] = self.run_id
+
+        self.wandb.init(**init_object)
+        print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")
+
+    def log(self, log, **kwargs) -> None:
+        if self.verbose:
+            print(log)
```
print(log)
|
||||||
self.wandb.log(log, **kwargs)
|
self.wandb.log(log, **kwargs)
|
||||||
|
|
||||||
def log_images(self, images, captions=[], image_section="images", **kwargs):
|
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Takes a tensor of images and a list of captions and logs them to wandb.
|
Takes a tensor of images and a list of captions and logs them to wandb.
|
||||||
"""
|
"""
|
||||||
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
||||||
self.log({ image_section: wandb_images }, **kwargs)
|
self.wandb.log({ image_section: wandb_images }, **kwargs)
|
||||||
|
|
||||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
|
||||||
"""
|
|
||||||
Saves a state_dict to disk and uploads it
|
|
||||||
"""
|
|
||||||
full_path = str(self.data_path / relative_path)
|
|
||||||
torch.save(state_dict, full_path)
|
|
||||||
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
|
|
||||||
|
|
||||||
def save_file(self, file_path, base_path=None, **kwargs):
|
def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
|
||||||
"""
|
|
||||||
Uploads a file from disk to wandb
|
|
||||||
"""
|
|
||||||
if base_path is None:
|
if base_path is None:
|
||||||
base_path = self.data_path
|
# Then we take the basepath as the parent of the file_path
|
||||||
|
base_path = Path(file_path).parent
|
||||||
self.wandb.save(str(file_path), base_path = str(base_path))
|
self.wandb.save(str(file_path), base_path = str(base_path))
|
||||||
|
|
||||||
|
def log_error(self, error_string, step=None, **kwargs) -> None:
|
||||||
|
if self.verbose:
|
||||||
|
print(error_string)
|
||||||
|
self.wandb.log({"error": error_string, **kwargs}, step=step)
|
||||||
|
|
||||||
|
logger_type_map = {
|
||||||
|
'console': ConsoleLogger,
|
||||||
|
'wandb': WandbLogger,
|
||||||
|
}
|
||||||
|
def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
|
||||||
|
if logger_type == 'custom':
|
||||||
|
raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
|
||||||
|
try:
|
||||||
|
logger_class = logger_type_map[logger_type]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
|
||||||
|
return logger_class(data_path, **kwargs)
|
||||||
|
|
||||||
|
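The `create_logger` factory above resolves a string key to a logger class through a plain dict. A standalone sketch of that registry pattern (the classes here are minimal stand-ins, not the library's implementations):

```python
# Minimal sketch of the type-map factory used by create_logger above.
# BaseLogger/ConsoleLogger here are illustrative stand-ins.
class BaseLogger:
    def __init__(self, data_path, **kwargs):
        self.data_path = data_path

class ConsoleLogger(BaseLogger):
    pass

logger_type_map = {'console': ConsoleLogger}

def create_logger(logger_type, data_path, **kwargs):
    try:
        logger_class = logger_type_map[logger_type]
    except KeyError:
        # Unknown keys fail loudly with the list of valid choices
        raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
    return logger_class(data_path, **kwargs)

logger = create_logger('console', './data')
print(type(logger).__name__)  # ConsoleLogger
```

The same dict-plus-lookup shape is reused below for loaders and savers, which keeps all three factories uniform.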
```python
class BaseLoader:
    """
    An abstract class representing an object that can load a model checkpoint.
    Parameters:
        data_path (str): A file path for storing temporary data.
    """
    def __init__(self, data_path: str, **kwargs):
        self.data_path = Path(data_path)

    def init(self, logger: BaseLogger, **kwargs) -> None:
        raise NotImplementedError

    def recall(self) -> dict:
        raise NotImplementedError

class UrlLoader(BaseLoader):
    """
    A loader that downloads the file from a url and loads it
    Parameters:
        data_path (str): A file path for storing temporary data.
        url (str): The url to download the file from.
    """
    def __init__(self, data_path: str, url: str, **kwargs):
        super().__init__(data_path, **kwargs)
        self.url = url

    def init(self, logger: BaseLogger, **kwargs) -> None:
        # Makes sure the file exists to be downloaded
        pass  # TODO: Actually implement that

    def recall(self) -> dict:
        # Download the file
        save_path = self.data_path / 'loaded_checkpoint.pth'
        urllib.request.urlretrieve(self.url, str(save_path))
        # Load the file
        return torch.load(str(save_path), map_location='cpu')

class LocalLoader(BaseLoader):
    """
    A loader that loads a file from a local path
    Parameters:
        data_path (str): A file path for storing temporary data.
        file_path (str): The path to the file to load.
    """
    def __init__(self, data_path: str, file_path: str, **kwargs):
        super().__init__(data_path, **kwargs)
        self.file_path = Path(file_path)

    def init(self, logger: BaseLogger, **kwargs) -> None:
        # Makes sure the file exists to be loaded
        if not self.file_path.exists():
            raise FileNotFoundError(f'Model not found at {self.file_path}')

    def recall(self) -> dict:
        # Load the file
        return torch.load(str(self.file_path), map_location='cpu')

class WandbLoader(BaseLoader):
    """
    A loader that loads a model from an existing wandb run
    """
    def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
        super().__init__(data_path, **kwargs)
        self.run_path = wandb_run_path
        self.file_path = wandb_file_path

    def init(self, logger: BaseLogger, **kwargs) -> None:
        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
        # Make sure the file can be downloaded
        if self.wandb.run is not None and self.run_path is None:
            self.run_path = self.wandb.run.path
            assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'
        assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'
        assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
        os.environ["WANDB_SILENT"] = "true"
        pass  # TODO: Actually implement that

    def recall(self) -> dict:
        file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
        return torch.load(file_reference.name, map_location='cpu')

loader_type_map = {
    'url': UrlLoader,
    'local': LocalLoader,
    'wandb': WandbLoader,
}

def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
    if loader_type == 'custom':
        raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
    try:
        loader_class = loader_type_map[loader_type]
    except KeyError:
        raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
    return loader_class(data_path, **kwargs)
```
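Every loader exposes the same two-step contract: `init` validates that the source is reachable, and `recall` returns the deserialized state dict. A dependency-free stand-in illustrating that contract (JSON stands in for `torch.load`, and the class name is illustrative):

```python
import json
import os
import tempfile

# Stand-in for the loader contract above: init validates, recall deserializes.
# Real loaders return a torch state dict; json keeps this sketch self-contained.
class JsonFileLoader:
    def __init__(self, file_path):
        self.file_path = file_path

    def init(self):
        # Fail early, before training starts, if the checkpoint is missing
        if not os.path.exists(self.file_path):
            raise FileNotFoundError(f'Model not found at {self.file_path}')

    def recall(self) -> dict:
        with open(self.file_path) as f:
            return json.load(f)

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'checkpoint.json')
    with open(path, 'w') as f:
        json.dump({'epoch': 3}, f)
    loader = JsonFileLoader(path)
    loader.init()
    state = loader.recall()
print(state['epoch'])  # 3
```

Validating in `init` rather than in `recall` is the point of the split: a bad path surfaces during tracker setup instead of hours into a run.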
```python
class BaseSaver:
    def __init__(self,
        data_path: str,
        save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
        save_best_to: Optional[Union[str, bool]] = 'best.pth',
        save_meta_to: str = './',
        save_type: str = 'checkpoint',
        **kwargs
    ):
        self.data_path = Path(data_path)
        self.save_latest_to = save_latest_to
        self.saving_latest = save_latest_to is not None and save_latest_to is not False
        self.save_best_to = save_best_to
        self.saving_best = save_best_to is not None and save_best_to is not False
        self.save_meta_to = save_meta_to
        self.save_type = save_type
        assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
        assert self.save_meta_to is not None, '`save_meta_to` must be provided'
        assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'

    def init(self, logger: BaseLogger, **kwargs) -> None:
        raise NotImplementedError

    def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
        """
        Save a general file under save_meta_to
        """
        raise NotImplementedError

class LocalSaver(BaseSaver):
    def __init__(self,
        data_path: str,
        **kwargs
    ):
        super().__init__(data_path, **kwargs)

    def init(self, logger: BaseLogger, **kwargs) -> None:
        # Makes sure the directory exists to be saved to
        print(f"Saving {self.save_type} locally")
        if not self.data_path.exists():
            self.data_path.mkdir(parents=True)

    def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
        # Copy the file to save_path
        save_path_file_name = Path(save_path).name
        print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
        shutil.copy(local_path, save_path)

class WandbSaver(BaseSaver):
    def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
        super().__init__(data_path, **kwargs)
        self.run_path = wandb_run_path

    def init(self, logger: BaseLogger, **kwargs) -> None:
        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
        os.environ["WANDB_SILENT"] = "true"
        # Makes sure that the user can upload to this run
        if self.run_path is not None:
            entity, project, run_id = self.run_path.split("/")
            self.run = self.wandb.init(entity=entity, project=project, id=run_id)
        else:
            assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
            self.run = self.wandb.run
        # TODO: Now actually check if upload is possible
        print(f"Saving to wandb run {self.run.path}-{self.run.name}")

    def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
        # In order to log something in the correct place in wandb, we need to have the same file structure here
        save_path_file_name = Path(save_path).name
        print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
        save_path = Path(self.data_path) / save_path
        save_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(local_path, save_path)
        self.run.save(str(save_path), base_path = str(self.data_path), policy='now')

class HuggingfaceSaver(BaseSaver):
    def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
        super().__init__(data_path, **kwargs)
        self.huggingface_repo = huggingface_repo
        self.token_path = token_path

    def init(self, logger: BaseLogger, **kwargs):
        # Makes sure this user can upload to the repo
        self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
        try:
            identity = self.hub.whoami()  # Errors if not logged in
            # Then we are logged in
        except:
            # We are not logged in. Use the token_path to set the token.
            if not os.path.exists(self.token_path):
                raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
            with open(self.token_path, "r") as f:
                token = f.read().strip()
            self.hub.HfApi.set_access_token(token)
            identity = self.hub.whoami()
        print(f"Saving to huggingface repo {self.huggingface_repo}")

    def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
        # Saving to huggingface is easy, we just need to upload the file with the correct name
        save_path_file_name = Path(save_path).name
        print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
        self.hub.upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=str(save_path),
            repo_id=self.huggingface_repo
        )

saver_type_map = {
    'local': LocalSaver,
    'wandb': WandbSaver,
    'huggingface': HuggingfaceSaver
}

def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
    if saver_type == 'custom':
        raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
    try:
        saver_class = saver_type_map[saver_type]
    except KeyError:
        raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
    return saver_class(data_path, **kwargs)
```
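The `save_latest_to` / `save_best_to` fields above are plain `str.format` templates: the keyword arguments later forwarded by `Tracker.save` (epoch, samples seen, and so on) are interpolated into the destination filename. A small illustration (the template string itself is made up for the example):

```python
# save_latest_to / save_best_to are str.format templates filled from the
# kwargs forwarded by Tracker.save; this particular template is illustrative.
save_latest_to = 'checkpoints/epoch_{epoch}_samples_{samples_seen}.pth'

# Extra kwargs the template does not mention (e.g. next_task) are ignored.
path = save_latest_to.format(epoch=4, samples_seen=128000, next_task='val')
print(path)  # checkpoints/epoch_4_samples_128000.pth
```

This is why a fixed name like `'latest.pth'` also works: a template with no placeholders formats to itself.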
```python
class Tracker:
    def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
        self.data_path = Path(data_path)
        if not dummy_mode:
            if overwrite_data_path:
                if self.data_path.exists():
                    shutil.rmtree(self.data_path)
                self.data_path.mkdir(parents=True)
            else:
                assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
            if not self.data_path.exists():
                self.data_path.mkdir(parents=True)
        self.logger: BaseLogger = None
        self.loader: Optional[BaseLoader] = None
        self.savers: List[BaseSaver] = []
        self.dummy_mode = dummy_mode

    def init(self, full_config: BaseModel, extra_config: dict):
        assert self.logger is not None, '`logger` must be set before `init` is called'
        if self.dummy_mode:
            # The only thing we need is a loader
            if self.loader is not None:
                self.loader.init(self.logger)
            return
        assert len(self.savers) > 0, '`savers` must be set before `init` is called'
        self.logger.init(full_config, extra_config)
        if self.loader is not None:
            self.loader.init(self.logger)
        for saver in self.savers:
            saver.init(self.logger)

    def add_logger(self, logger: BaseLogger):
        self.logger = logger

    def add_loader(self, loader: BaseLoader):
        self.loader = loader

    def add_saver(self, saver: BaseSaver):
        self.savers.append(saver)

    def log(self, *args, **kwargs):
        if self.dummy_mode:
            return
        self.logger.log(*args, **kwargs)

    def log_images(self, *args, **kwargs):
        if self.dummy_mode:
            return
        self.logger.log_images(*args, **kwargs)

    def log_file(self, *args, **kwargs):
        if self.dummy_mode:
            return
        self.logger.log_file(*args, **kwargs)

    def save_config(self, current_config_path: str, config_name = 'config.json'):
        if self.dummy_mode:
            return
        # Save the config under config_name in the root folder of data_path
        shutil.copy(current_config_path, self.data_path / config_name)
        for saver in self.savers:
            remote_path = Path(saver.save_meta_to) / config_name
            saver.save_file(current_config_path, str(remote_path))

    def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
        """
        Gets the state dict to be saved and writes it to file_path.
        If save_type is 'checkpoint', we save the entire trainer state dict.
        If save_type is 'model', we save only the model state dict.
        """
        assert save_type in ['checkpoint', 'model']
        if save_type == 'checkpoint':
            trainer.save(file_path, overwrite=True, **kwargs)
        elif save_type == 'model':
            if isinstance(trainer, DiffusionPriorTrainer):
                prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
                state_dict = trainer.unwrap_model(prior).state_dict()
                torch.save(state_dict, file_path)
            elif isinstance(trainer, DecoderTrainer):
                decoder = trainer.accelerator.unwrap_model(trainer.decoder)
                if trainer.use_ema:
                    trainable_unets = decoder.unets
                    decoder.unets = trainer.unets  # Swap EMA unets in
                    state_dict = decoder.state_dict()
                    decoder.unets = trainable_unets  # Swap back
                else:
                    state_dict = decoder.state_dict()
                torch.save(state_dict, file_path)
            else:
                raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
        return Path(file_path)

    def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
        if self.dummy_mode:
            return
        if not is_best and not is_latest:
            # Nothing to do
            return
        # Save the checkpoint and model to data_path
        checkpoint_path = self.data_path / 'checkpoint.pth'
        self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
        model_path = self.data_path / 'model.pth'
        self._save_state_dict(trainer, 'model', model_path, **kwargs)
        print("Saved cached models")
        # Call the save methods on the savers
        for saver in self.savers:
            local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
            if saver.saving_latest and is_latest:
                latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
                try:
                    saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
                except Exception as e:
                    self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
                    print(f'Error saving checkpoint: {e}')
            if saver.saving_best and is_best:
                best_checkpoint_path = saver.save_best_to.format(**kwargs)
                try:
                    saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
                except Exception as e:
                    self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
                    print(f'Error saving checkpoint: {e}')

    def recall(self):
        if self.loader is not None:
            return self.loader.recall()
        else:
            raise ValueError('No loader specified')
```
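A `Tracker` is assembled from one logger, an optional loader, and one or more savers, and `init` then wires them together (each saver receives the logger). A toy sketch of that composition order (the component classes are stand-ins, not the real ones):

```python
# Toy sketch of the Tracker composition pattern above.
# ToyLogger/ToySaver are illustrative stand-ins for the real components.
class ToyLogger:
    def init(self, full_config, extra_config):
        self.initialized = True

class ToySaver:
    def init(self, logger):
        # Savers get a handle on the logger so they can report errors
        self.logger = logger

class Tracker:
    def __init__(self):
        self.logger = None
        self.savers = []

    def add_logger(self, logger):
        self.logger = logger

    def add_saver(self, saver):
        self.savers.append(saver)

    def init(self, full_config, extra_config):
        assert self.logger is not None, '`logger` must be set before `init` is called'
        assert len(self.savers) > 0, '`savers` must be set before `init` is called'
        self.logger.init(full_config, extra_config)
        for saver in self.savers:
            saver.init(self.logger)

tracker = Tracker()
tracker.add_logger(ToyLogger())
tracker.add_saver(ToySaver())
tracker.init(full_config={}, extra_config={})
```

The assertions mirror the real class: components must all be registered before `init`, so misconfiguration fails at startup rather than at the first save.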
```diff
@@ -15,6 +15,7 @@ from dalle2_pytorch.dalle2_pytorch import (
     DiffusionPriorNetwork,
     XClipAdapter
 )
+from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver

 # helper functions
```
```diff
@@ -44,13 +45,66 @@ class TrainSplitConfig(BaseModel):
             raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
         return fields

+class TrackerLogConfig(BaseModel):
+    log_type: str = 'console'
+    verbose: bool = False
+
+    class Config:
+        # Each individual log type has its own arguments that will be passed through the config
+        extra = "allow"
+
+    def create(self, data_path: str):
+        kwargs = self.dict()
+        return create_logger(self.log_type, data_path, **kwargs)
+
+class TrackerLoadConfig(BaseModel):
+    load_from: Optional[str] = None
+
+    class Config:
+        extra = "allow"
+
+    def create(self, data_path: str):
+        kwargs = self.dict()
+        if self.load_from is None:
+            return None
+        return create_loader(self.load_from, data_path, **kwargs)
+
+class TrackerSaveConfig(BaseModel):
+    save_to: str = 'local'
+    save_all: bool = False
+    save_latest: bool = True
+    save_best: bool = True
+
+    class Config:
+        extra = "allow"
+
+    def create(self, data_path: str):
+        kwargs = self.dict()
+        return create_saver(self.save_to, data_path, **kwargs)
+
 class TrackerConfig(BaseModel):
-    tracker_type: str = 'console' # Decoder currently supports console and wandb
-    data_path: str = './models' # The path where files will be saved locally
-    init_config: Dict[str, Any] = None
-    wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
-    wandb_project: str = ''
-    verbose: bool = False # Whether to print console logging for non-console trackers
+    data_path: str = '.tracker_data'
+    overwrite_data_path: bool = False
+    log: TrackerLogConfig
+    load: Optional[TrackerLoadConfig]
+    save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
+
+    def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
+        tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
+        # Add the logger
+        tracker.add_logger(self.log.create(self.data_path))
+        # Add the loader
+        if self.load is not None:
+            tracker.add_loader(self.load.create(self.data_path))
+        # Add the saver or savers
+        if isinstance(self.save, list):
+            for save_config in self.save:
+                tracker.add_saver(save_config.create(self.data_path))
+        else:
+            tracker.add_saver(self.save.create(self.data_path))
+        # Initialize all the components and verify that all data is valid
+        tracker.init(full_config, extra_config)
+        return tracker

 # diffusion prior pydantic classes
```
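Under the new schema, the flat `tracker` section of a training config becomes a nested one with `log`, `load`, and `save` sub-sections (`save` may be a single saver or a list). A plausible fragment of that section, as a plain dict; the concrete values and the pass-through keys such as `file_path` and `save_latest_to` are illustrative, not prescribed:

```python
# Illustrative tracker section for the nested schema above. The top-level
# keys mirror TrackerConfig; extra keys inside each sub-section are passed
# through to the matching create_* factory because the models set extra="allow".
tracker_config = {
    "data_path": ".tracker_data",
    "overwrite_data_path": True,
    "log": {
        "log_type": "console",      # dispatched by create_logger
        "verbose": True,
    },
    "load": {
        "load_from": "local",       # dispatched by create_loader
        "file_path": "./models/latest.pth",
    },
    "save": [                       # one or more savers
        {
            "save_to": "local",     # dispatched by create_saver
            "save_latest_to": "latest.pth",
        },
    ],
}
```

The three `*_type`-style keys (`log_type`, `load_from`, `save_to`) select the class; everything else in the same sub-section is forwarded to that class's constructor.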
```diff
@@ -238,6 +292,7 @@ class DecoderTrainConfig(BaseModel):
     epochs: int = 20
     lr: SingularOrIterable(float) = 1e-4
     wd: SingularOrIterable(float) = 0.01
+    find_unused_parameters: bool = True
     max_grad_norm: SingularOrIterable(float) = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
```
```diff
@@ -247,9 +302,6 @@ class DecoderTrainConfig(BaseModel):
     use_ema: bool = True
     ema_beta: float = 0.999
     amp: bool = False
-    save_all: bool = False # Whether to preserve all checkpoints
-    save_latest: bool = True # Whether to always save the latest checkpoint
-    save_best: bool = True # Whether to save the best checkpoint
     unet_training_mask: ListOrTuple(bool) = None # If None, use all unets

 class DecoderEvaluateConfig(BaseModel):
```
```diff
@@ -271,7 +323,6 @@ class TrainDecoderConfig(BaseModel):
     train: DecoderTrainConfig
     evaluate: DecoderEvaluateConfig
     tracker: TrackerConfig
-    load: DecoderLoadConfig
     seed: int = 0

     @classmethod
```
```diff
@@ -289,22 +340,22 @@ class TrainDecoderConfig(BaseModel):
             # Then something else errored and we should just pass through
             return values

-        using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
+        using_text_encodings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
         using_clip = exists(decoder_config.clip)
         img_emb_url = data_config.img_embeddings_url
         text_emb_url = data_config.text_embeddings_url

-        if using_text_embeddings:
+        if using_text_encodings:
             # Then we need some way to get the embeddings
             assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'

         if using_clip:
-            if using_text_embeddings:
+            if using_text_encodings:
                 assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
             else:
                 assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'

         if text_emb_url:
-            assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
+            assert using_text_encodings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."

         return values
```
```diff
@@ -505,12 +505,7 @@ class DecoderTrainer(nn.Module):
         self.accelerator.save(save_obj, str(path))

-    def load(self, path, only_model = False, strict = True):
-        path = Path(path)
-        assert path.exists()
-
-        loaded_obj = torch.load(str(path), map_location = 'cpu')
-
+    def load_state_dict(self, loaded_obj, only_model = False, strict = True):
         if version.parse(__version__) != version.parse(loaded_obj['version']):
             self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
@@ -530,6 +525,14 @@ class DecoderTrainer(nn.Module):
             assert 'ema' in loaded_obj
             self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)

+    def load(self, path, only_model = False, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path), map_location = 'cpu')
+
+        self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)
+
         return loaded_obj

     @property
```
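The refactor above separates deserialization from disk I/O: `load` now only reads the file and delegates to `load_state_dict`, so a caller that already holds a loaded object (for example one recalled by a tracker loader) can skip the filesystem entirely. The delegation pattern in miniature (a toy class, with `pickle` standing in for `torch.load`):

```python
import os
import pickle
import tempfile

# Miniature version of the load/load_state_dict split above (toy class).
class ToyTrainer:
    def load_state_dict(self, loaded_obj):
        # Consumes an already-deserialized object; no filesystem access here
        self.state = loaded_obj['state']

    def load(self, path):
        # Reads from disk, then delegates to load_state_dict
        with open(path, 'rb') as f:
            loaded_obj = pickle.load(f)
        self.load_state_dict(loaded_obj)
        return loaded_obj

trainer = ToyTrainer()
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'checkpoint.pkl')
    with open(path, 'wb') as f:
        pickle.dump({'state': 42}, f)
    obj = trainer.load(path)
print(trainer.state)  # 42
```

Either entry point leaves the object in the same state, which is what lets the tracker's `recall()` result be fed straight into `load_state_dict`.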
```diff
@@ -1,6 +1,11 @@
 import time
 import importlib

+# helper functions
+
+def exists(val):
+    return val is not None
+
 # time helpers

 class Timer:
```
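The `exists` helper added to utils is the usual None check; note that it is strictly an identity test against `None`, so falsy values like `0` and `''` still count as existing:

```python
def exists(val):
    # True for any value that is not None, including falsy ones
    return val is not None

print(exists(0), exists(''), exists(None))  # True True False
```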
```diff
@@ -1 +1 @@
-__version__ = '0.12.2'
+__version__ = '0.15.2'
```
106  train_decoder.py
@@ -1,11 +1,12 @@
 from pathlib import Path
+from typing import List

 from dalle2_pytorch.trainer import DecoderTrainer
 from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
-from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
-from dalle2_pytorch.train_configs import TrainDecoderConfig
+from dalle2_pytorch.trackers import Tracker
+from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
 from dalle2_pytorch.utils import Timer, print_ribbon
-from dalle2_pytorch.dalle2_pytorch import resize_image_to
+from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
 from clip import tokenize

 import torchvision
@@ -239,42 +240,33 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
         metrics[metric_name] = metrics_tensor[i].item()
     return metrics

-def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths):
+def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
     """
     Logs the model with an appropriate method depending on the tracker
     """
-    if isinstance(relative_paths, str):
-        relative_paths = [relative_paths]
-    for relative_path in relative_paths:
-        local_path = str(tracker.data_path / relative_path)
-        trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses)
-        tracker.save_file(local_path)
+    tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)

-def recall_trainer(tracker, trainer, recall_source=None, **load_config):
+def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
     """
     Loads the model with an appropriate method depending on the tracker
     """
-    trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}"))
-    local_filepath = tracker.recall_file(recall_source, **load_config)
-    state_dict = trainer.load(local_filepath)
-    return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0)
+    trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
+    state_dict = tracker.recall()
+    trainer.load_state_dict(state_dict, only_model=False, strict=True)
+    return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)

 def train(
     dataloaders,
-    decoder,
-    accelerator,
-    tracker,
+    decoder: Decoder,
+    accelerator: Accelerator,
+    tracker: Tracker,
     inference_device,
-    load_config=None,
     evaluate_config=None,
     epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
     validation_samples = None,
     epochs = 20,
     n_sample_images = 5,
     save_every_n_samples = 100000,
-    save_all=False,
-    save_latest=True,
-    save_best=True,
     unet_training_mask=None,
     condition_on_text_encodings=False,
     **kwargs
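This hunk routes all checkpoint I/O through the tracker object instead of through explicit relative paths handed to the caller. A stdlib-only sketch of that round trip (the `MemoryTracker` class here is hypothetical, not the library's tracker API):

```python
class MemoryTracker:
    """Hypothetical in-memory tracker mirroring the save/recall interface above."""
    def __init__(self):
        self.latest = None
        self.best = None

    def save(self, state_dict, is_best=False, is_latest=True, **metadata):
        # the tracker, not the caller, decides where checkpoints live;
        # callers only declare which slots ("latest", "best") to fill
        record = {**state_dict, **metadata}
        if is_latest:
            self.latest = record
        if is_best:
            self.best = record

    def recall(self):
        # hand back the latest checkpoint for resuming
        return self.latest

tracker = MemoryTracker()
tracker.save({'weights': [1, 2]}, epoch=4, samples_seen=1000)
resumed = tracker.recall()
```

This is why `save_trainer` collapses to a single `tracker.save(...)` call and `recall_trainer` no longer needs a `recall_source` argument: the storage backend is chosen when the tracker is built, not at every save site.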
@@ -299,13 +291,13 @@ def train(
     val_sample = 0
     step = lambda: int(trainer.step.item())

-    if exists(load_config) and exists(load_config.source):
-        start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
+    if tracker.loader is not None:
+        start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
         if next_task == 'train':
             sample = recalled_sample
         if next_task == 'val':
             val_sample = recalled_sample
-        accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
+        accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
         accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
     trainer.to(device=inference_device)
@@ -399,19 +391,14 @@ def train(
             }

             if is_master:
-                tracker.log(log_data, step=step(), verbose=True)
+                tracker.log(log_data, step=step())

             if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
                 # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
                 print("Saving snapshot")
                 last_snapshot = sample
                 # We need to know where the model should be saved
-                save_paths = []
-                if save_latest:
-                    save_paths.append("latest.pth")
-                if save_all:
-                    save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth")
-                save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
+                save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
                 if exists(n_sample_images) and n_sample_images > 0:
                     trainer.eval()
                     train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
@@ -486,7 +473,7 @@ def train(
             if is_master:
                 unet_average_val_loss = all_average_val_losses.mean(dim=0)
                 val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
-                tracker.log(val_loss_map, step=step(), verbose=True)
+                tracker.log(val_loss_map, step=step())
             next_task = 'eval'

         if next_task == 'eval':
@@ -494,7 +481,7 @@ def train(
             accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
             evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
             if is_master:
-                tracker.log(evaluation, step=step(), verbose=True)
+                tracker.log(evaluation, step=step())
             next_task = 'sample'
             val_sample = 0
@@ -509,22 +496,16 @@ def train(
                 tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

             print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
-            # Get the same paths
-            save_paths = []
-            if save_latest:
-                save_paths.append("latest.pth")
+            is_best = False
             if all_average_val_losses is not None:
                 average_loss = all_average_val_losses.mean(dim=0).item()
-                if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
-                    save_paths.append("best.pth")
+                if len(validation_losses) == 0 or average_loss < min(validation_losses):
+                    is_best = True
                 validation_losses.append(average_loss)
-            save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
+            save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
             next_task = 'train'

-def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None):
-    """
-    Creates a tracker of the specified type and initializes special features based on the full config
-    """
+def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
     tracker_config = config.tracker
     accelerator_config = {
         "Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
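The saving step in the hunk above reduces the old `save_paths` bookkeeping to a single `is_best` flag: a checkpoint is best when its average validation loss beats every loss recorded so far. The comparison in isolation:

```python
def is_best_checkpoint(validation_losses, average_loss):
    # best when there is no history yet, or when strictly below the running minimum
    return len(validation_losses) == 0 or average_loss < min(validation_losses)

losses = []
first = is_best_checkpoint(losses, 0.9)   # no history yet
losses.append(0.9)
second = is_best_checkpoint(losses, 1.1)  # worse than the 0.9 minimum
losses.append(1.1)
third = is_best_checkpoint(losses, 0.5)   # new minimum
```

Note the strict `<`: a checkpoint that only ties the current minimum is not re-saved as best.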
@@ -532,40 +513,16 @@ create_tracker(accelerator, config, config_path, tracker_type=None, data_pat
         "NumProcesses": accelerator.num_processes,
         "MixedPrecision": accelerator.mixed_precision
     }
-    init_config = { "config": {**config.dict(), **accelerator_config} }
-    data_path = data_path or tracker_config.data_path
-    tracker_type = tracker_type or tracker_config.tracker_type
-
-    if tracker_type == "dummy":
-        tracker = DummyTracker(data_path)
-        tracker.init(**init_config)
-    elif tracker_type == "console":
-        tracker = ConsoleTracker(data_path)
-        tracker.init(**init_config)
-    elif tracker_type == "wandb":
-        # We need to initialize the resume state here
-        load_config = config.load
-        if load_config.source == "wandb" and load_config.resume:
-            # Then we are resuming the run load_config["run_path"]
-            run_id = load_config.run_path.split("/")[-1]
-            init_config["id"] = run_id
-            init_config["resume"] = "must"
-
-        init_config["entity"] = tracker_config.wandb_entity
-        init_config["project"] = tracker_config.wandb_project
-        tracker = WandbTracker(data_path)
-        tracker.init(**init_config)
-        tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute()))
-    else:
-        raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
+    tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
+    tracker.save_config(config_path, config_name='decoder_config.json')
     return tracker

-def initialize_training(config, config_path):
+def initialize_training(config: TrainDecoderConfig, config_path):
     # Make sure if we are not loading, distributed models are initialized to the same values
     torch.manual_seed(config.seed)

     # Set up accelerator for configurable distributed training
-    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
     accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

     # Set up data
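`create_tracker` now takes a `dummy` flag instead of a `tracker_type` string, so every rank goes through the same factory and non-master ranks simply get a no-op tracker. A minimal sketch of that pattern (the class and function names here are illustrative, not the library's actual trackers):

```python
class ConsoleTracker:
    """Real tracker: actually emits logs (stand-in for wandb/console trackers)."""
    def log(self, data, **kwargs):
        print(data)

class DummyTracker:
    """No-op tracker handed to non-master ranks."""
    def log(self, data, **kwargs):
        pass  # swallow everything so worker ranks stay silent

def make_tracker(dummy=False):
    # every rank calls the same factory; only the master gets a real tracker
    return DummyTracker() if dummy else ConsoleTracker()

# rank 0 is the master in a 4-process run
trackers = [make_tracker(dummy=(rank != 0)) for rank in range(4)]
```

Because both classes share the same `log` interface, the training loop needs no rank checks around logging calls.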
@@ -592,7 +549,7 @@ def initialize_training(config, config_path):
     num_parameters = sum(p.numel() for p in decoder.parameters())

     # Create and initialize the tracker if we are the master
-    tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
+    tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)

     has_img_embeddings = config.data.img_embeddings_url is not None
     has_text_embeddings = config.data.text_embeddings_url is not None
@@ -622,7 +579,6 @@ def initialize_training(config, config_path):
     train(dataloaders, decoder, accelerator,
         tracker=tracker,
         inference_device=accelerator.device,
-        load_config=config.load,
         evaluate_config=config.evaluate,
         condition_on_text_encodings=conditioning_on_text,
         **config.train.dict(),