Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-12 11:34:29 +01:00

Compare commits (22 commits)
88f516b5db · b3e646fd3b · 6a59c7093d · a6cdbe0b9c · e928ae5c34 · 1bd8a7835a · f33453df9f · 1e4bb2bafb · ee75515c7d · ec68243479 · 3afdcdfe86 · b9a908ff75 · e1fe3089df · 6d477d7654 · 531fe4b62f · ec5a77fc55 · fac63c61bc · 3d23ba4aa5 · 282c35930f · 27b0f7ca0d · 7b0edf9e42 · a922a539de
README.md
@@ -20,18 +20,20 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu

 - Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.

-<img src="./samples/oxford.png" width="600px" />
+<img src="./samples/oxford.png" width="450px" />

 *ongoing at 21k steps*

 - <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application

 - <a href="https://github.com/rom1504">Romain</a> has scaled up training to 800 GPUs with the available scripts without any issues

 ## Pre-Trained Models

 - LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
 - Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
 - Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
-- DALL-E 2 🚧
+- DALL-E 2 🚧 - <a href="https://github.com/LAION-AI/dalle2-laion">DALL-E 2 Laion repository</a>

 ## Appreciation
@@ -579,7 +581,8 @@ unet1 = Unet(
     image_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
-    dim_mults=(1, 2, 4, 8)
+    dim_mults=(1, 2, 4, 8),
+    cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
 ).cuda()

 unet2 = Unet(
@@ -596,8 +599,7 @@ decoder = Decoder(
     clip = clip,
     timesteps = 100,
     image_cond_drop_prob = 0.1,
-    text_cond_drop_prob = 0.5,
-    condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
+    text_cond_drop_prob = 0.5
 ).cuda()

 for unet_number in (1, 2):
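Taken together, the two hunks above move text conditioning from the `Decoder` to the individual `Unet`s. A minimal sketch of the resulting pattern, assuming the README's surrounding example (dimensions, image sizes, and the CLIP adapter are illustrative):

```python
from dalle2_pytorch import Unet, Decoder, OpenAIClipAdapter

clip = OpenAIClipAdapter()

# each unet now declares whether it consumes text encodings
unet1 = Unet(dim = 32, image_embed_dim = 512, cond_dim = 128, channels = 3,
             dim_mults = (1, 2, 4, 8), cond_on_text_encodings = True)
unet2 = Unet(dim = 32, image_embed_dim = 512, cond_dim = 128, channels = 3,
             dim_mults = (1, 2, 4, 8))

# the decoder no longer takes condition_on_text_encodings;
# it derives the flag from its unets (see the property added further below)
decoder = Decoder(unet = (unet1, unet2), image_sizes = (128, 256), clip = clip,
                  timesteps = 100, image_cond_drop_prob = 0.1, text_cond_drop_prob = 0.5)
assert decoder.condition_on_text_encodings  # true because unet1 opted in
```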
@@ -988,34 +990,7 @@ dataset = ImageEmbeddingDataset(

 #### `train_diffusion_prior.py`

-This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
-Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
-
-#### Usage
-
-```bash
-$ python train_diffusion_prior.py
-```
-
-The most significant parameters for the script are as follows:
-
-- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
-- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
-- `image-embed-dim`, default = `768` - 768 corresponds to the ViT-L/14 embedding size; change it to whatever your chosen ViT generates
-- `learning-rate`, default = `1.1e-4`
-- `weight-decay`, default = `6.02e-2`
-- `max-grad-norm`, default = `0.5`
-- `batch-size`, default = `10 ** 4`
-- `num-epochs`, default = `5`
-- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
+For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)

 ## CLI (wip)
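The removed note about the script passing `text_embed` and `image_embed` directly corresponds to the prior's pre-computed-embedding mode (`clip = None`, per the removed parameter list). A hedged sketch of that call pattern, with random tensors standing in for real CLIP embeddings:

```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

prior_network = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8)

# clip = None signals the prior to consume pre-computed embeddings
diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = None,
    image_embed_dim = 512,
    timesteps = 100,
    cond_drop_prob = 0.2,
    condition_on_text_encodings = False
)

text_embeds = torch.randn(4, 512)   # stand-in for pre-computed CLIP text embeddings
image_embeds = torch.randn(4, 512)  # stand-in for pre-computed CLIP image embeddings

loss = diffusion_prior(text_embed = text_embeds, image_embed = image_embeds)
loss.backward()
```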
@@ -1112,15 +1087,6 @@ Once built, images will be saved to the same directory the command is invoked

 }
 ```

-```bibtex
-@inproceedings{Tu2022MaxViTMV,
-    title  = {MaxViT: Multi-Axis Vision Transformer},
-    author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
-    year   = {2022},
-    url    = {https://arxiv.org/abs/2204.01697}
-}
-```

 ```bibtex
 @article{Yu2021VectorquantizedIM,
     title = {Vector-quantized Image Modeling with Improved VQGAN},
@@ -91,21 +91,83 @@ Each metric can be enabled by setting its configuration. The configuration keys

 **<ins>Tracker</ins>:**

-Selects which tracker to use and configures it.
+Selects how the experiment will be tracked.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
-| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
-| `data_path` | No | `./models` | Where the tracker will store local data. |
-| `verbose` | No | `False` | Enables console logging for non-console trackers. |
+| `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. |
+| `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. |
+| `log` | Yes | N/A | Logging configuration. |
+| `load` | No | `None` | Checkpoint loading configuration. |
+| `save` | Yes | N/A | Checkpoint/Model saving configuration. |

-Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
+Tracking is split up into three sections:
+* Log: Where to save run metadata and image output. Options are `console` or `wandb`.
+* Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`.
+* Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`.

-**<ins>Load</ins>:**
-
-Selects where to load a pretrained model from.
-| Option | Required | Default | Description |
-| ------ | -------- | ------- | ----------- |
-| `source` | No | `None` | Supports `file` or `wandb`. |
-| `resume` | No | `False` | If the tracker supports resuming the run, resume it. |
-
-Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).
+**Logging:**
+
+If using `console` there is no further configuration than setting `log_type` to `console`.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `log_type` | Yes | N/A | Must be `console`. |
+
+If using `wandb`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `log_type` | Yes | N/A | Must be `wandb`. |
+| `wandb_entity` | Yes | N/A | The wandb entity to log to. |
+| `wandb_project` | Yes | N/A | The wandb project to save the run to. |
+| `wandb_run_name` | No | `None` | The wandb run name. |
+| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
+| `wandb_resume` | No | `False` | Whether to resume an old run. |
+
+**Loading:**
+
+If using `local`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | Must be `local`. |
+| `file_path` | Yes | N/A | The path to the checkpoint file. |
+
+If using `url`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | Must be `url`. |
+| `url` | Yes | N/A | The url of the checkpoint file. |
+
+If using `wandb`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | Must be `wandb`. |
+| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. |
+| `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. |
+
+**Saving:**
+Unlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a run.
+
+All save locations have these configuration options
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
+| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
+| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
+| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (saves with EMA if EMA is enabled). |
+
+If using `local`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `local`. |
+
+If using `huggingface`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `huggingface`. |
+| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
+| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
+| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
+
+If using `wandb`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `wandb`. |
+| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. |
@@ -80,20 +80,32 @@
         }
     },
     "tracker": {
-        "tracker_type": "console",
-        "data_path": "./models",
-        "wandb_entity": "",
-        "wandb_project": "",
-        "verbose": false
-    },
-    "load": {
-        "source": null,
-        "run_path": "",
-        "file_path": "",
-        "resume": false
+        "data_path": "./models",
+        "overwrite_data_path": true,
+        "log": {
+            "log_type": "wandb",
+            "wandb_entity": "your_wandb",
+            "wandb_project": "your_project",
+            "verbose": true
+        },
+        "load": {
+            "load_from": null
+        },
+        "save": [{
+            "save_to": "wandb"
+        }, {
+            "save_to": "huggingface",
+            "huggingface_repo": "Veldrovive/test_model",
+            "save_all": true,
+            "save_latest": true,
+            "save_best": true,
+            "save_type": "model"
+        }]
     }
 }
@@ -63,11 +63,16 @@ def default(val, d):
         return val
     return d() if callable(d) else d

-def cast_tuple(val, length = 1):
+def cast_tuple(val, length = None):
     if isinstance(val, list):
         val = tuple(val)

-    return val if isinstance(val, tuple) else ((val,) * length)
+    out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
+
+    if exists(length):
+        assert len(out) == length
+
+    return out

 def module_device(module):
     return next(module.parameters()).device
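A self-contained sketch of the new `cast_tuple` semantics (helpers copied from the hunk above): a scalar still broadcasts, but an explicit length is now enforced against tuples that are passed through.

```python
def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = None):
    if isinstance(val, list):
        val = tuple(val)
    out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
    if exists(length):
        assert len(out) == length
    return out

assert cast_tuple(3) == (3,)             # length defaults to 1
assert cast_tuple(3, 4) == (3, 3, 3, 3)  # scalar broadcast to the requested length
assert cast_tuple([1, 2], 2) == (1, 2)   # lists become tuples and are length-checked
```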
@@ -330,21 +335,25 @@ def approx_standard_normal_cdf(x):
 def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
     assert x.shape == means.shape == log_scales.shape

+    # attempting to correct nan gradients when learned variance is turned on
+    # in the setting of deepspeed fp16
+    eps = 1e-12 if x.dtype == torch.float32 else 1e-3
+
     centered_x = x - means
     inv_stdv = torch.exp(-log_scales)
     plus_in = inv_stdv * (centered_x + 1. / 255.)
     cdf_plus = approx_standard_normal_cdf(plus_in)
     min_in = inv_stdv * (centered_x - 1. / 255.)
     cdf_min = approx_standard_normal_cdf(min_in)
-    log_cdf_plus = log(cdf_plus)
-    log_one_minus_cdf_min = log(1. - cdf_min)
+    log_cdf_plus = log(cdf_plus, eps = eps)
+    log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)
     cdf_delta = cdf_plus - cdf_min

     log_probs = torch.where(x < -thres,
         log_cdf_plus,
         torch.where(x > thres,
             log_one_minus_cdf_min,
-            log(cdf_delta)))
+            log(cdf_delta, eps = eps)))

     return log_probs
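The `log(..., eps = eps)` calls assume the repository's clamped-log helper defined earlier in the file; a sketch of the idea (exact signature assumed):

```python
import torch

def log(t, eps = 1e-12):
    # clamp before taking the log so underflowed probabilities cannot
    # produce -inf values and nan gradients (the fp16 failure mode above)
    return torch.log(t.clamp(min = eps))

probs = torch.tensor([0.0, 1e-20, 0.5])
print(log(probs, eps = 1e-3))  # finite everywhere, even where probs is 0
```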
@@ -485,14 +494,16 @@ class NoiseScheduler(nn.Module):
 # diffusion prior

 class LayerNorm(nn.Module):
-    def __init__(self, dim):
+    def __init__(self, dim, eps = 1e-5):
         super().__init__()
-        self.gamma = nn.Parameter(torch.ones(dim))
-        self.register_buffer("beta", torch.zeros(dim))
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(dim))

     def forward(self, x):
-        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+        x = x / x.amax(dim = -1, keepdim = True).detach()
+        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = -1, keepdim = True)
+        return (x - mean) * (var + self.eps).rsqrt() * self.g

 class ChanLayerNorm(nn.Module):
     def __init__(self, dim, eps = 1e-5):
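The rewritten norm first divides by the detached row maximum, so the variance is computed on values safely inside half-precision range. A small numerical check, copying the class from the hunk above:

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    # copy of the rewritten norm above
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x / x.amax(dim = -1, keepdim = True).detach()
        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)
        return (x - mean) * (var + self.eps).rsqrt() * self.g

x = torch.randn(2, 8) * 1e4  # values this large overflow fp16 once squared
out = LayerNorm(8)(x)
print(out.mean(dim = -1))                      # ~0 per row
print(out.var(dim = -1, unbiased = False))     # ~1 per row
```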
@@ -501,10 +512,10 @@ class ChanLayerNorm(nn.Module):
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

     def forward(self, x):
+        x = x / x.amax(dim = 1, keepdim = True).detach()
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) / (var + self.eps).sqrt() * self.g
+        return (x - mean) * (var + self.eps).rsqrt() * self.g

 class Residual(nn.Module):
     def __init__(self, fn):
@@ -624,10 +635,13 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        rotary_emb = None
+        rotary_emb = None,
+        pb_relax_alpha = 32 ** 2
     ):
         super().__init__()
-        self.scale = dim_head ** -0.5
+        self.pb_relax_alpha = pb_relax_alpha
+        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)

         self.heads = heads
         inner_dim = dim_head * heads
@@ -691,7 +705,10 @@ class Attention(nn.Module):

         # attention

-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        sim = sim * self.pb_relax_alpha
+
+        attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)

         # aggregate values
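This is the PB-Relax trick (hence `pb_relax_alpha`, the stabilization used in CogView): logits are computed at `scale / alpha`, the detached row max is subtracted, and `alpha` is multiplied back before the softmax. The softmax is mathematically unchanged, but intermediate values stay in half-precision range. A standalone check of the equivalence:

```python
import torch

alpha = 32 ** 2
scale = 64 ** -0.5                # the usual 1/sqrt(dim_head)
qk = torch.randn(2, 8, 8) * 100   # stand-in for q·k, large enough to worry fp16

# reference attention weights
attn_ref = (qk * scale).softmax(dim = -1)

# pb-relax: small-scale logits, max-subtraction, then restore alpha
sim = qk * (scale / alpha)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn_pb = (sim * alpha).softmax(dim = -1)

# identical up to floating point noise, since softmax is shift-invariant per row
print(torch.allclose(attn_ref, attn_pb, atol = 1e-6))  # True
```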
@@ -1093,7 +1110,11 @@ class DiffusionPrior(nn.Module):

 # decoder

-def Upsample(dim, dim_out = None):
+def ConvTransposeUpsample(dim, dim_out = None):
     dim_out = default(dim_out, dim)
     return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
+
+def NearestUpsample(dim, dim_out = None):
+    dim_out = default(dim_out, dim)
+    return nn.Sequential(
+        nn.Upsample(scale_factor = 2, mode = 'nearest'),
@@ -1110,11 +1131,12 @@ class SinusoidalPosEmb(nn.Module):
         self.dim = dim

     def forward(self, x):
+        dtype, device = x.dtype, x.device
         half_dim = self.dim // 2
         emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
+        emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
         emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
-        return torch.cat((emb.sin(), emb.cos()), dim = -1)
+        return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)

 class Block(nn.Module):
     def __init__(
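The fix makes the embedding follow the input dtype instead of silently upcasting (`torch.arange` previously produced float32 frequencies). A quick check using float64 as the stand-in for a non-default dtype; the fp16 case behaves the same under amp on GPU:

```python
import math
import torch
from torch import nn
from einops import rearrange

class SinusoidalPosEmb(nn.Module):
    # copy of the patched module above
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        dtype, device = x.dtype, x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)

times = torch.arange(4, dtype = torch.float64)
print(SinusoidalPosEmb(32)(times).dtype)  # torch.float64, matching the input
```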
@@ -1201,10 +1223,12 @@ class CrossAttention(nn.Module):
         dim_head = 64,
         heads = 8,
         dropout = 0.,
-        norm_context = False
+        norm_context = False,
+        pb_relax_alpha = 32 ** 2
     ):
         super().__init__()
-        self.scale = dim_head ** -0.5
+        self.pb_relax_alpha = pb_relax_alpha
+        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
         self.heads = heads
         inner_dim = dim_head * heads
@@ -1250,26 +1274,15 @@ class CrossAttention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)

-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        sim = sim * self.pb_relax_alpha
+
+        attn = sim.softmax(dim = -1)

         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)

-class GridAttention(nn.Module):
-    def __init__(self, *args, window_size = 8, **kwargs):
-        super().__init__()
-        self.window_size = window_size
-        self.attn = Attention(*args, **kwargs)
-
-    def forward(self, x):
-        h, w = x.shape[-2:]
-        wsz = self.window_size
-        x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
-        out = self.attn(x)
-        out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
-        return out
-
 class LinearAttention(nn.Module):
     def __init__(
         self,
@@ -1351,6 +1364,7 @@ class Unet(nn.Module):
         dim_mults=(1, 2, 4, 8),
         channels = 3,
         channels_out = None,
+        self_attn = False,
         attn_dim_head = 32,
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/

@@ -1369,6 +1383,8 @@ class Unet(nn.Module):
         cross_embed_downsample_kernel_sizes = (2, 4),
         memory_efficient = False,
         scale_skip_connection = False,
+        nearest_upsample = False,
+        final_conv_kernel_size = 1,
         **kwargs
     ):
         super().__init__()
@@ -1395,6 +1411,8 @@ class Unet(nn.Module):
         dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))

+        num_stages = len(in_out)
+
         # time, image embeddings, and optional text encoding

         cond_dim = default(cond_dim, dim)
@@ -1458,14 +1476,16 @@ class Unet(nn.Module):

         attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)

+        self_attn = cast_tuple(self_attn, num_stages)
+
+        create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
+
         # resnet block klass

-        resnet_groups = cast_tuple(resnet_groups, len(in_out))
+        resnet_groups = cast_tuple(resnet_groups, num_stages)
         top_level_resnet_group = first(resnet_groups)

-        num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
-
-        assert len(resnet_groups) == len(in_out)
+        num_resnet_blocks = cast_tuple(num_resnet_blocks, num_stages)

         # downsample klass
@@ -1473,6 +1493,10 @@ class Unet(nn.Module):
         if cross_embed_downsample:
             downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)

+        # upsample klass
+
+        upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
+
         # give memory efficient unet an initial resnet block

         self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
@@ -1483,9 +1507,9 @@ class Unet(nn.Module):
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)

         skip_connect_dims = [] # keeping track of skip connection dimensions

-        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
             is_first = ind == 0
             is_last = ind >= (num_resolutions - 1)
             layer_cond_dim = cond_dim if not is_first else None
@@ -1493,35 +1517,47 @@ class Unet(nn.Module):
             dim_layer = dim_out if memory_efficient else dim_in
             skip_connect_dims.append(dim_layer)

+            attention = nn.Identity()
+            if layer_self_attn:
+                attention = create_self_attn(dim_layer)
+            elif sparse_attn:
+                attention = Residual(LinearAttention(dim_layer, **attn_kwargs))
+
             self.downs.append(nn.ModuleList([
                 downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
                 ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
-                Residual(LinearAttention(dim_layer, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+                attention,
                 downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
             ]))

         mid_dim = dims[-1]

         self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
-        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
+        self.mid_attn = create_self_attn(mid_dim)
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

-        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
             is_last = ind >= (len(in_out) - 1)
             layer_cond_dim = cond_dim if not is_last else None

             skip_connect_dim = skip_connect_dims.pop()

+            attention = nn.Identity()
+            if layer_self_attn:
+                attention = create_self_attn(dim_out)
+            elif sparse_attn:
+                attention = Residual(LinearAttention(dim_out, **attn_kwargs))
+
             self.ups.append(nn.ModuleList([
                 ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
-                Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
-                Upsample(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
+                attention,
+                upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
             ]))

         self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
-        self.to_out = nn.Conv2d(dim, self.channels_out, 3, padding = 1)
+        self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)

         # if the current settings for the unet are not correct
         # for cascading DDPM, then reinit the unet with the right settings
@@ -1595,6 +1631,7 @@ class Unet(nn.Module):

         # time conditioning

+        time = time.type_as(x)
         time_hiddens = self.to_time_hiddens(time)

         time_tokens = self.to_time_tokens(time_hiddens)
@@ -1694,18 +1731,19 @@ class Unet(nn.Module):

         hiddens = []

-        for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
+        for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
             if exists(pre_downsample):
                 x = pre_downsample(x)

             x = init_block(x, t, c)
-            x = sparse_attn(x)
             hiddens.append(x)

             for resnet_block in resnet_blocks:
                 x = resnet_block(x, t, c)
                 hiddens.append(x)

+            x = attn(x)
+            hiddens.append(x)
+
             if exists(post_downsample):
                 x = post_downsample(x)
@@ -1716,17 +1754,17 @@ class Unet(nn.Module):

         x = self.mid_block2(x, t, mid_c)

-        connect_skip = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)
+        connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)

-        for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
+        for init_block, resnet_blocks, attn, upsample in self.ups:
             x = connect_skip(x)
             x = init_block(x, t, c)
-            x = sparse_attn(x)

             for resnet_block in resnet_blocks:
                 x = connect_skip(x)
                 x = resnet_block(x, t, c)

+            x = attn(x)
             x = upsample(x)

         x = torch.cat((x, r), dim = 1)
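Taken together, the `Unet` changes add three knobs: per-resolution full self-attention (`self_attn`, broadcast to one flag per stage via `cast_tuple`), a choice of upsampler (`nearest_upsample`), and a configurable final convolution. A hedged configuration sketch (dimensions illustrative):

```python
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    self_attn = (False, True, True, True),  # skip full attention at the highest resolution
    nearest_upsample = True,                # nearest-neighbor + conv instead of transposed conv
    final_conv_kernel_size = 3              # padding is derived as kernel_size // 2
)
```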
@@ -1892,10 +1930,6 @@ class Decoder(nn.Module):
             self.unets.append(one_unet)
             self.vaes.append(one_vae.copy_for_eval())

-        # determine from unets whether conditioning on text encoding is needed
-
-        self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
-
         # create noise schedulers per unet

         if not exists(beta_schedule):
@@ -1974,6 +2008,10 @@ class Decoder(nn.Module):
     def device(self):
         return self._dummy.device

+    @property
+    def condition_on_text_encodings(self):
+        return any([unet.cond_on_text_encodings for unet in self.unets])
+
     def get_unet(self, unet_number):
         assert 0 < unet_number <= len(self.unets)
         index = unet_number - 1
@@ -2229,7 +2267,8 @@ class Decoder(nn.Module):
         image_embed = None,
         text_encodings = None,
         text_mask = None,
-        unet_number = None
+        unet_number = None,
+        return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
     ):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
@@ -2279,7 +2318,12 @@ class Decoder(nn.Module):
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)

-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
+        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
+
+        if not return_lowres_cond_image:
+            return losses
+
+        return losses, lowres_cond_img

 # main class
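With the new flag, the decoder's forward can hand back the degraded conditioning image next to the loss, which helps eyeball what an upsampler unet actually sees during training. A hedged usage sketch (assumes `decoder` and `images` from a training loop):

```python
# unchanged default path: just the loss
loss = decoder(images, unet_number = 2)

# debugging path: also recover the low-resolution conditioning images
loss, lowres_cond_images = decoder(images, unet_number = 2, return_lowres_cond_image = True)
```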
@@ -1,12 +1,15 @@
 import urllib.request
 import os
 from pathlib import Path
+import importlib
+import shutil
 from itertools import zip_longest
+from typing import Optional, List, Union
+from pydantic import BaseModel

 import torch
 from torch import nn

 from dalle2_pytorch.utils import import_or_print_error
+from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer

 # constants
@@ -27,126 +30,484 @@ def load_wandb_file(run_path, file_path, **kwargs):
 def load_local_file(file_path, **kwargs):
     return file_path

-# base class
-
-class BaseTracker(nn.Module):
-    def __init__(self, data_path = DEFAULT_DATA_PATH):
-        super().__init__()
+class BaseLogger:
+    """
+    An abstract class representing an object that can log data.
+    Parameters:
+        data_path (str): A file path for storing temporary data.
+        verbose (bool): Whether or not to always print logs to the console.
+    """
+    def __init__(self, data_path: str, verbose: bool = False, **kwargs):
         self.data_path = Path(data_path)
         self.data_path.mkdir(parents = True, exist_ok = True)
+        self.verbose = verbose

-    def init(self, config, **kwargs):
-        raise NotImplementedError
-
-    def log(self, log, **kwargs):
-        raise NotImplementedError
-
-    def log_images(self, images, **kwargs):
-        raise NotImplementedError
-
-    def save_state_dict(self, state_dict, relative_path, **kwargs):
-        raise NotImplementedError
-
-    def recall_state_dict(self, recall_source, *args, **kwargs):
+    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
         """
-        Loads a state dict from any source.
-        Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
-        this should not be linked to any individual tracker.
+        Initializes the logger.
+        Errors if the logger is invalid.
         """
-        # TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
-        if recall_source == 'wandb':
-            return torch.load(load_wandb_file(*args, **kwargs))
-        elif recall_source == 'local':
-            return torch.load(load_local_file(*args, **kwargs))
-        else:
-            raise ValueError('`recall_source` must be one of `wandb` or `local`')
-
-    def save_file(self, file_path, **kwargs):
-        raise NotImplementedError
-
-    def recall_file(self, recall_source, *args, **kwargs):
-        if recall_source == 'wandb':
-            return load_wandb_file(*args, **kwargs)
-        elif recall_source == 'local':
-            return load_local_file(*args, **kwargs)
-        else:
-            raise ValueError('`recall_source` must be one of `wandb` or `local`')
+        raise NotImplementedError

-# Tracker that no-ops all calls except for recall
+    def log(self, log, **kwargs) -> None:
+        raise NotImplementedError

-class DummyTracker(BaseTracker):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
+        raise NotImplementedError

-    def init(self, config, **kwargs):
-        pass
+    def log_file(self, file_path, **kwargs) -> None:
+        raise NotImplementedError

-    def log(self, log, **kwargs):
-        pass
+    def log_error(self, error_string, **kwargs) -> None:
+        raise NotImplementedError

-    def log_images(self, images, **kwargs):
-        pass
+class ConsoleLogger(BaseLogger):
+    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
+        print("Logging to console")

-    def save_state_dict(self, state_dict, relative_path, **kwargs):
-        pass
-
-    def save_file(self, file_path, **kwargs):
-        pass
-
-# basic stdout class
-
-class ConsoleTracker(BaseTracker):
-    def init(self, **config):
-        print(config)
-
-    def log(self, log, **kwargs):
+    def log(self, log, **kwargs) -> None:
         print(log)

-    def log_images(self, images, **kwargs): # noop for logging images
-        pass
-
-    def save_state_dict(self, state_dict, relative_path, **kwargs):
-        torch.save(state_dict, str(self.data_path / relative_path))
-
-    def save_file(self, file_path, **kwargs):
-        # This is a no-op for local file systems since it is already saved locally
+    def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
         pass

-# basic wandb class
+    def log_file(self, file_path, **kwargs) -> None:
+        pass

-class WandbTracker(BaseTracker):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
+    def log_error(self, error_string, **kwargs) -> None:
+        print(error_string)

+class WandbLogger(BaseLogger):
+    """
+    Logs to a wandb run.
+    Parameters:
+        data_path (str): A file path for storing temporary data.
+        wandb_entity (str): The wandb entity to log to.
+        wandb_project (str): The wandb project to log to.
+        wandb_run_id (str): The wandb run id to resume.
+        wandb_run_name (str): The wandb run name to use.
+        wandb_resume (bool): Whether to resume a wandb run.
+    """
+    def __init__(self,
+        data_path: str,
+        wandb_entity: str,
+        wandb_project: str,
+        wandb_run_id: Optional[str] = None,
+        wandb_run_name: Optional[str] = None,
+        wandb_resume: bool = False,
+        **kwargs
+    ):
+        super().__init__(data_path, **kwargs)
+        self.entity = wandb_entity
+        self.project = wandb_project
+        self.run_id = wandb_run_id
+        self.run_name = wandb_run_name
+        self.resume = wandb_resume
+
+    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
+        assert self.entity is not None, "wandb_entity must be specified for wandb logger"
+        assert self.project is not None, "wandb_project must be specified for wandb logger"
+        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
         os.environ["WANDB_SILENT"] = "true"
+        # Initializes the wandb run
+        init_object = {
+            "entity": self.entity,
+            "project": self.project,
+            "config": {**full_config.dict(), **extra_config}
+        }
+        if self.run_name is not None:
+            init_object['name'] = self.run_name
+        if self.resume:
+            assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
+            if self.run_name is not None:
+                print("You are renaming a run. I hope that is what you intended.")
+            init_object['resume'] = 'must'
+            init_object['id'] = self.run_id

-    def init(self, **config):
-        self.wandb.init(**config)
+        self.wandb.init(**init_object)
+        print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")

-    def log(self, log, verbose=False, **kwargs):
-        if verbose:
+    def log(self, log, **kwargs) -> None:
+        if self.verbose:
             print(log)
         self.wandb.log(log, **kwargs)

-    def log_images(self, images, captions=[], image_section="images", **kwargs):
+    def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
         """
         Takes a tensor of images and a list of captions and logs them to wandb.
         """
         wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
-        self.log({ image_section: wandb_images }, **kwargs)
-
-    def save_state_dict(self, state_dict, relative_path, **kwargs):
-        """
-        Saves a state_dict to disk and uploads it
-        """
-        full_path = str(self.data_path / relative_path)
-        torch.save(state_dict, full_path)
-        self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
+        self.wandb.log({ image_section: wandb_images }, **kwargs)

-    def save_file(self, file_path, base_path=None, **kwargs):
-        """
-        Uploads a file from disk to wandb
-        """
+    def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
         if base_path is None:
-            base_path = self.data_path
+            # Then we take the basepath as the parent of the file_path
+            base_path = Path(file_path).parent
         self.wandb.save(str(file_path), base_path = str(base_path))

+    def log_error(self, error_string, step=None, **kwargs) -> None:
+        if self.verbose:
+            print(error_string)
+        self.wandb.log({"error": error_string, **kwargs}, step=step)
+
+logger_type_map = {
+    'console': ConsoleLogger,
+    'wandb': WandbLogger,
+}
+def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
+    if logger_type == 'custom':
+        raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
+    try:
+        logger_class = logger_type_map[logger_type]
+    except KeyError:
+        raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
+    return logger_class(data_path, **kwargs)
+
+class BaseLoader:
+    """
+    An abstract class representing an object that can load a model checkpoint.
+    Parameters:
+        data_path (str): A file path for storing temporary data.
+    """
+    def __init__(self, data_path: str, **kwargs):
+        self.data_path = Path(data_path)
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        raise NotImplementedError
+
+    def recall(self) -> dict:
+        raise NotImplementedError
+
+class UrlLoader(BaseLoader):
+    """
+    A loader that downloads the file from a url and loads it
+    Parameters:
+        data_path (str): A file path for storing temporary data.
+        url (str): The url to download the file from.
+    """
+    def __init__(self, data_path: str, url: str, **kwargs):
+        super().__init__(data_path, **kwargs)
+        self.url = url
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        # Makes sure the file exists to be downloaded
+        pass  # TODO: Actually implement that
+
+    def recall(self) -> dict:
+        # Download the file
+        save_path = self.data_path / 'loaded_checkpoint.pth'
+        urllib.request.urlretrieve(self.url, str(save_path))
+        # Load the file
+        return torch.load(str(save_path), map_location='cpu')
+
+class LocalLoader(BaseLoader):
+    """
+    A loader that loads a file from a local path
+    Parameters:
+        data_path (str): A file path for storing temporary data.
+        file_path (str): The path to the file to load.
+    """
+    def __init__(self, data_path: str, file_path: str, **kwargs):
+        super().__init__(data_path, **kwargs)
+        self.file_path = Path(file_path)
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        # Makes sure the file exists to be loaded
+        if not self.file_path.exists():
+            raise FileNotFoundError(f'Model not found at {self.file_path}')
+
+    def recall(self) -> dict:
+        # Load the file
+        return torch.load(str(self.file_path), map_location='cpu')
+
+class WandbLoader(BaseLoader):
+    """
+    A loader that loads a model from an existing wandb run
+    """
+    def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
+        super().__init__(data_path, **kwargs)
+        self.run_path = wandb_run_path
+        self.file_path = wandb_file_path
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
+        # Make sure the file can be downloaded
+        if self.wandb.run is not None and self.run_path is None:
+            self.run_path = self.wandb.run.path
+            assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'
+        assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'
+        assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
+        os.environ["WANDB_SILENT"] = "true"
+        pass  # TODO: Actually implement that
+
+    def recall(self) -> dict:
+        file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
+        return torch.load(file_reference.name, map_location='cpu')
+
+loader_type_map = {
+    'url': UrlLoader,
+    'local': LocalLoader,
+    'wandb': WandbLoader,
+}
+def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
+    if loader_type == 'custom':
+        raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
+    try:
+        loader_class = loader_type_map[loader_type]
+    except KeyError:
+        raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
+    return loader_class(data_path, **kwargs)
+
+class BaseSaver:
+    def __init__(self,
+        data_path: str,
+        save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
+        save_best_to: Optional[Union[str, bool]] = 'best.pth',
+        save_meta_to: str = './',
+        save_type: str = 'checkpoint',
+        **kwargs
+    ):
+        self.data_path = Path(data_path)
+        self.save_latest_to = save_latest_to
+        self.saving_latest = save_latest_to is not None and save_latest_to is not False
+        self.save_best_to = save_best_to
+        self.saving_best = save_best_to is not None and save_best_to is not False
+        self.save_meta_to = save_meta_to
+        self.save_type = save_type
+        assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
+        assert self.save_meta_to is not None, '`save_meta_to` must be provided'
+        assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        raise NotImplementedError
+
+    def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
+        """
+        Save a general file under save_meta_to
+        """
+        raise NotImplementedError
+
+class LocalSaver(BaseSaver):
+    def __init__(self,
+        data_path: str,
+        **kwargs
+    ):
+        super().__init__(data_path, **kwargs)
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        # Makes sure the directory exists to be saved to
+        print(f"Saving {self.save_type} locally")
+        if not self.data_path.exists():
+            self.data_path.mkdir(parents=True)
+
+    def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
+        # Copy the file to save_path
+        save_path_file_name = Path(save_path).name
+        print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
+        shutil.copy(local_path, save_path)
+
+class WandbSaver(BaseSaver):
+    def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
+        super().__init__(data_path, **kwargs)
+        self.run_path = wandb_run_path
+
+    def init(self, logger: BaseLogger, **kwargs) -> None:
+        self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
+        os.environ["WANDB_SILENT"] = "true"
+        # Makes sure that the user can upload to this run
+        if self.run_path is not None:
+            entity, project, run_id = self.run_path.split("/")
+            self.run = self.wandb.init(entity=entity, project=project, id=run_id)
+        else:
+            assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
+            self.run = self.wandb.run
+        # TODO: Now actually check if upload is possible
+        print(f"Saving to wandb run {self.run.path}-{self.run.name}")
+
+    def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
+        # In order to log something in the correct place in wandb, we need to have the same file structure here
+        save_path_file_name = Path(save_path).name
+        print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
+        save_path = Path(self.data_path) / save_path
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        shutil.copy(local_path, save_path)
+        self.run.save(str(save_path), base_path = str(self.data_path), policy='now')
+
+class HuggingfaceSaver(BaseSaver):
+    def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
+        super().__init__(data_path, **kwargs)
+        self.huggingface_repo = huggingface_repo
+        self.token_path = token_path
+
+    def init(self, logger: BaseLogger, **kwargs):
+        # Makes sure this user can upload to the repo
+        self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
+        try:
+            identity = self.hub.whoami()  # Errors if not logged in
+            # Then we are logged in
+        except:
+            # We are not logged in. Use the token_path to set the token.
+            if not os.path.exists(self.token_path):
+                raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
+            with open(self.token_path, "r") as f:
+                token = f.read().strip()
+            self.hub.HfApi.set_access_token(token)
+            identity = self.hub.whoami()
+        print(f"Saving to huggingface repo {self.huggingface_repo}")
+
+    def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
+        # Saving to huggingface is easy, we just need to upload the file with the correct name
+        save_path_file_name = Path(save_path).name
+        print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
+        self.hub.upload_file(
+            path_or_fileobj=str(local_path),
+            path_in_repo=str(save_path),
+            repo_id=self.huggingface_repo
+        )
+
+saver_type_map = {
+    'local': LocalSaver,
+    'wandb': WandbSaver,
+    'huggingface': HuggingfaceSaver
+}
+def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
+    if saver_type == 'custom':
+        raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
+    try:
+        saver_class = saver_type_map[saver_type]
+    except KeyError:
+        raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
+    return saver_class(data_path, **kwargs)
+
+class Tracker:
+    def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
+        self.data_path = Path(data_path)
+        if not dummy_mode:
+            if overwrite_data_path:
+                if self.data_path.exists():
+                    shutil.rmtree(self.data_path)
+                self.data_path.mkdir(parents=True)
+            else:
+                assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
+                if not self.data_path.exists():
+                    self.data_path.mkdir(parents=True)
+        self.logger: BaseLogger = None
+        self.loader: Optional[BaseLoader] = None
+        self.savers: List[BaseSaver] = []
+        self.dummy_mode = dummy_mode
+
+    def init(self, full_config: BaseModel, extra_config: dict):
+        assert self.logger is not None, '`logger` must be set before `init` is called'
+        if self.dummy_mode:
+            # The only thing we need is a loader
+            if self.loader is not None:
+                self.loader.init(self.logger)
+            return
+        assert len(self.savers) > 0, '`savers` must be set before `init` is called'
+        self.logger.init(full_config, extra_config)
+        if self.loader is not None:
+            self.loader.init(self.logger)
+        for saver in self.savers:
+            saver.init(self.logger)
+
+    def add_logger(self, logger: BaseLogger):
+        self.logger = logger
+
+    def add_loader(self, loader: BaseLoader):
+        self.loader = loader
+
+    def add_saver(self, saver: BaseSaver):
+        self.savers.append(saver)
+
+    def log(self, *args, **kwargs):
+        if self.dummy_mode:
+            return
+        self.logger.log(*args, **kwargs)
+
+    def log_images(self, *args, **kwargs):
+        if self.dummy_mode:
+            return
+        self.logger.log_images(*args, **kwargs)
+
+    def log_file(self, *args, **kwargs):
+        if self.dummy_mode:
+            return
+        self.logger.log_file(*args, **kwargs)
+
+    def save_config(self, current_config_path: str, config_name = 'config.json'):
+        if self.dummy_mode:
+            return
+        # Save the config under config_name in the root folder of data_path
+        shutil.copy(current_config_path, self.data_path / config_name)
+        for saver in self.savers:
+            remote_path = Path(saver.save_meta_to) / config_name
+            saver.save_file(current_config_path, str(remote_path))
+
+    def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
+        """
+        Gets the state dict to be saved and writes it to file_path.
+        If save_type is 'checkpoint', we save the entire trainer state dict.
+        If save_type is 'model', we save only the model state dict.
+        """
+        assert save_type in ['checkpoint', 'model']
+        if save_type == 'checkpoint':
+            trainer.save(file_path, overwrite=True, **kwargs)
+        elif save_type == 'model':
+            if isinstance(trainer, DiffusionPriorTrainer):
+                prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
+                state_dict = trainer.unwrap_model(prior).state_dict()
+                torch.save(state_dict, file_path)
+            elif isinstance(trainer, DecoderTrainer):
+                decoder = trainer.accelerator.unwrap_model(trainer.decoder)
+                if trainer.use_ema:
+                    trainable_unets = decoder.unets
+                    decoder.unets = trainer.unets  # Swap EMA unets in
+                    state_dict = decoder.state_dict()
+                    decoder.unets = trainable_unets  # Swap back
+                else:
+                    state_dict = decoder.state_dict()
+                torch.save(state_dict, file_path)
+            else:
+                raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
+        return Path(file_path)
+
+    def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
+        if self.dummy_mode:
+            return
+        if not is_best and not is_latest:
+            # Nothing to do
+            return
+        # Save the checkpoint and model to data_path
+        checkpoint_path = self.data_path / 'checkpoint.pth'
+        self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
+        model_path = self.data_path / 'model.pth'
+        self._save_state_dict(trainer, 'model', model_path, **kwargs)
+        print("Saved cached models")
+        # Call the save methods on the savers
+        for saver in self.savers:
+            local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
+            if saver.saving_latest and is_latest:
+                latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
+                try:
+                    saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
+                except Exception as e:
+                    self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
+                    print(f'Error saving checkpoint: {e}')
+            if saver.saving_best and is_best:
+                best_checkpoint_path = saver.save_best_to.format(**kwargs)
+                try:
+                    saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
+                except Exception as e:
+                    self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
+                    print(f'Error saving checkpoint: {e}')
+
+    def recall(self):
+        if self.loader is not None:
+            return self.loader.recall()
+        else:
+            raise ValueError('No loader specified')
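The rewrite splits the old monolithic trackers into composable logger/loader/saver parts. A hedged sketch of wiring one up by hand, mirroring what `TrackerConfig.create` does in the next file (`config` stands for the run's pydantic config object):

```python
tracker = Tracker(data_path = './.tracker-data', overwrite_data_path = True)
tracker.add_logger(create_logger('console', tracker.data_path))
tracker.add_saver(create_saver('local', tracker.data_path, save_latest_to = 'latest.pth'))
tracker.init(full_config = config, extra_config = {})  # validates all components

tracker.log({'loss': 0.5}, step = 10)  # forwarded to the ConsoleLogger
```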
@@ -15,6 +15,7 @@ from dalle2_pytorch.dalle2_pytorch import (
     DiffusionPriorNetwork,
     XClipAdapter
 )
+from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver

 # helper functions
@@ -44,13 +45,66 @@ class TrainSplitConfig(BaseModel):
         raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
         return fields

+class TrackerLogConfig(BaseModel):
+    log_type: str = 'console'
+    verbose: bool = False
+
+    class Config:
+        # Each individual log type has its own arguments that will be passed through the config
+        extra = "allow"
+
+    def create(self, data_path: str):
+        kwargs = self.dict()
+        return create_logger(self.log_type, data_path, **kwargs)
+
+class TrackerLoadConfig(BaseModel):
+    load_from: Optional[str] = None
+
+    class Config:
+        extra = "allow"
+
+    def create(self, data_path: str):
+        kwargs = self.dict()
+        if self.load_from is None:
+            return None
+        return create_loader(self.load_from, data_path, **kwargs)
+
+class TrackerSaveConfig(BaseModel):
+    save_to: str = 'local'
+    save_all: bool = False
+    save_latest: bool = True
+    save_best: bool = True
+
+    class Config:
+        extra = "allow"
+
+    def create(self, data_path: str):
+        kwargs = self.dict()
+        return create_saver(self.save_to, data_path, **kwargs)
+
 class TrackerConfig(BaseModel):
-    tracker_type: str = 'console' # Decoder currently supports console and wandb
-    data_path: str = './models' # The path where files will be saved locally
-    init_config: Dict[str, Any] = None
-    wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
-    wandb_project: str = ''
-    verbose: bool = False # Whether to print console logging for non-console trackers
+    data_path: str = '.tracker_data'
+    overwrite_data_path: bool = False
+    log: TrackerLogConfig
+    load: Optional[TrackerLoadConfig]
+    save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
+
+    def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
+        tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
+        # Add the logger
+        tracker.add_logger(self.log.create(self.data_path))
+        # Add the loader
+        if self.load is not None:
+            tracker.add_loader(self.load.create(self.data_path))
+        # Add the saver or savers
+        if isinstance(self.save, list):
+            for save_config in self.save:
+                tracker.add_saver(save_config.create(self.data_path))
+        else:
+            tracker.add_saver(self.save.create(self.data_path))
+        # Initialize all the components and verify that all data is valid
+        tracker.init(full_config, extra_config)
+        return tracker

 # diffusion prior pydantic classes
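A hedged sketch of how these pydantic classes assemble a `Tracker` from a parsed config (`config` stands for the enclosing `TrainDecoderConfig`; field values illustrative; note `overwrite_data_path` defaults to `False`, so a fresh `data_path` is expected):

```python
tracker_config = TrackerConfig(
    data_path = '.tracker_data',
    log = TrackerLogConfig(log_type = 'console'),
    load = None,
    save = TrackerSaveConfig(save_to = 'local', save_latest_to = 'latest.pth')
)

# builds the logger/loader/savers, then validates everything via tracker.init(...)
tracker = tracker_config.create(full_config = config, extra_config = {}, dummy_mode = False)
tracker.log({'step': 0, 'loss': 1.23})
```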
@@ -162,6 +216,7 @@ class UnetConfig(BaseModel):
     cond_on_text_encodings: bool = None
     cond_dim: int = None
     channels: int = 3
+    self_attn: ListOrTuple(int)
     attn_dim_head: int = 32
     attn_heads: int = 16
@@ -238,6 +293,8 @@ class DecoderTrainConfig(BaseModel):
    epochs: int = 20
    lr: SingularOrIterable(float) = 1e-4
    wd: SingularOrIterable(float) = 0.01
    warmup_steps: Optional[SingularOrIterable(int)] = None
    find_unused_parameters: bool = True
    max_grad_norm: SingularOrIterable(float) = 0.5
    save_every_n_samples: int = 100000
    n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
@@ -247,9 +304,6 @@ class DecoderTrainConfig(BaseModel):
    use_ema: bool = True
    ema_beta: float = 0.999
    amp: bool = False
    save_all: bool = False # Whether to preserve all checkpoints
    save_latest: bool = True # Whether to always save the latest checkpoint
    save_best: bool = True # Whether to save the best checkpoint
    unet_training_mask: ListOrTuple(bool) = None # If None, use all unets

class DecoderEvaluateConfig(BaseModel):
@@ -271,7 +325,6 @@ class TrainDecoderConfig(BaseModel):
    train: DecoderTrainConfig
    evaluate: DecoderEvaluateConfig
    tracker: TrackerConfig
    load: DecoderLoadConfig
    seed: int = 0

    @classmethod
@@ -3,10 +3,13 @@ import copy
from pathlib import Path
from math import ceil
from functools import partial, wraps
from contextlib import nullcontext
from collections.abc import Iterable

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import autocast, GradScaler

from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -14,6 +17,8 @@ from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version

import pytorch_warmup as warmup

from ema_pytorch import EMA

from accelerate import Accelerator
@@ -162,19 +167,32 @@ class DiffusionPriorTrainer(nn.Module):
        group_wd_params = True,
        device = None,
        accelerator = None,
        verbose = True,
        **kwargs
    ):
        super().__init__()
        assert isinstance(diffusion_prior, DiffusionPrior)
        assert not exists(accelerator) or isinstance(accelerator, Accelerator)
        assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        # verbosity

        self.verbose = verbose

        # assign some helpful member vars

        self.accelerator = accelerator
        self.device = accelerator.device if exists(accelerator) else device
        self.text_conditioned = diffusion_prior.condition_on_text_encodings

        # setting the device

        if not exists(accelerator) and not exists(device):
            diffusion_prior_device = next(diffusion_prior.parameters()).device
            self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
            self.device = diffusion_prior_device
        else:
            self.device = accelerator.device if exists(accelerator) else device

        # save model

        self.diffusion_prior = diffusion_prior
@@ -210,11 +228,14 @@ class DiffusionPriorTrainer(nn.Module):

        # track steps internally

        self.register_buffer('step', torch.tensor([0]))
        self.register_buffer('step', torch.tensor([0], device = self.device))

    # accelerator wrappers

    def print(self, msg):
        if not self.verbose:
            return

        if exists(self.accelerator):
            self.accelerator.print(msg)
        else:
@@ -428,6 +449,7 @@ class DecoderTrainer(nn.Module):
        lr = 1e-4,
        wd = 1e-2,
        eps = 1e-8,
        warmup_steps = None,
        max_grad_norm = 0.5,
        amp = False,
        group_wd_params = True,
@@ -449,13 +471,15 @@ class DecoderTrainer(nn.Module):
        # be able to finely customize learning rate, weight decay
        # per unet

        lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))

        assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
        assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'

        optimizers = []
        schedulers = []
        warmup_schedulers = []

        for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
            optimizer = get_optimizer(
                unet.parameters(),
                lr = unet_lr,
@@ -467,6 +491,13 @@ class DecoderTrainer(nn.Module):

            optimizers.append(optimizer)

            scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

            warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
            warmup_schedulers.append(warmup_scheduler)

            schedulers.append(scheduler)

            if self.use_ema:
                self.ema_unets.append(EMA(unet, **ema_kwargs))
@@ -474,15 +505,27 @@ class DecoderTrainer(nn.Module):

        self.max_grad_norm = max_grad_norm

        self.register_buffer('step', torch.tensor([0.]))
        self.register_buffer('steps', torch.tensor([0] * self.num_unets))

        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
        schedulers = list(self.accelerator.prepare(*schedulers))

        self.decoder = decoder

        # store optimizers

        for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
            setattr(self, f'optim{opt_ind}', optimizer)

        # store schedulers

        for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
            setattr(self, f'sched{sched_ind}', scheduler)

        # store warmup schedulers

        self.warmup_schedulers = warmup_schedulers

    def save(self, path, overwrite = True, **kwargs):
        path = Path(path)
        assert not (path.exists() and not overwrite)
@@ -491,7 +534,7 @@ class DecoderTrainer(nn.Module):
        save_obj = dict(
            model = self.accelerator.unwrap_model(self.decoder).state_dict(),
            version = __version__,
            step = self.step.item(),
            steps = self.steps.cpu(),
            **kwargs
        )
@@ -505,30 +548,38 @@ class DecoderTrainer(nn.Module):

        self.accelerator.save(save_obj, str(path))

    def load_state_dict(self, loaded_obj, only_model = False, strict = True):
        if version.parse(__version__) != version.parse(loaded_obj['version']):
            self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')

        self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
        self.steps.copy_(loaded_obj['steps'])

        if only_model:
            return loaded_obj

        for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):

            optimizer_key = f'optim{ind}'
            optimizer = getattr(self, optimizer_key)
            warmup_scheduler = self.warmup_schedulers[ind]

            self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])

            if exists(warmup_scheduler):
                warmup_scheduler.last_step = last_step

        if self.use_ema:
            assert 'ema' in loaded_obj
            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)

    def load(self, path, only_model = False, strict = True):
        path = Path(path)
        assert path.exists()

        loaded_obj = torch.load(str(path), map_location = 'cpu')

        if version.parse(__version__) != version.parse(loaded_obj['version']):
            self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')

        self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

        if only_model:
            return loaded_obj

        for ind in range(0, self.num_unets):
            optimizer_key = f'optim{ind}'
            optimizer = getattr(self, optimizer_key)

            self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])

        if self.use_ema:
            assert 'ema' in loaded_obj
            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
        self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)

        return loaded_obj
@@ -536,6 +587,12 @@ class DecoderTrainer(nn.Module):
    def unets(self):
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    def increment_step(self, unet_number):
        assert 1 <= unet_number <= self.num_unets

        unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
        self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))

    def update(self, unet_number = None):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)
@@ -544,17 +601,25 @@ class DecoderTrainer(nn.Module):
        index = unet_number - 1

        optimizer = getattr(self, f'optim{index}')
        scheduler = getattr(self, f'sched{index}')

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients

        optimizer.step()
        optimizer.zero_grad()

        warmup_scheduler = self.warmup_schedulers[index]
        scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext

        with scheduler_context():
            scheduler.step()

        if self.use_ema:
            ema_unet = self.ema_unets[index]
            ema_unet.update()

        self.step += 1
        self.increment_step(unet_number)

    @torch.no_grad()
    @cast_torch_tensor
@@ -604,7 +669,6 @@ class DecoderTrainer(nn.Module):
        total_loss = 0.

        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            # with autocast(enabled = self.amp):
            with self.accelerator.autocast():
                loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
                loss = loss * chunk_size_frac
@@ -1,6 +1,11 @@
import time
import importlib

# helper functions

def exists(val):
    return val is not None

# time helpers

class Timer:
@@ -1 +1 @@
__version__ = '0.14.1'
__version__ = '0.16.15'
183 prior.md Normal file
@@ -0,0 +1,183 @@
# Diffusion Prior

This readme serves as an introduction to the diffusion prior.

## Intro

A properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected in some way, then the ability to translate between them could be extremely helpful.

### Motivation

Before we dive into the model, let's look at a quick example of where the model may be helpful.

For demonstration purposes, we will imagine that we wish to generate images from text using CLIP and a Decoder.

> [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and its caption. However, there is no guarantee that these embeddings lie in the same space. While the embeddings generated are ***close***, the image and text embeddings occupy two disjoint sets.
```python
# Load models
clip_model = clip.load("ViT-L/14")
decoder = Decoder(checkpoint="best.pth") # A decoder trained on CLIP image embeddings

# Retrieve a prompt from the user and encode it with CLIP
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
text_embedding = clip_model.encode_text(tokenized_text)

# Now, pass the text embedding to the decoder
predicted_image = decoder.sample(text_embedding)
```

> **Question**: *Can you spot the issue here?*
>
> **Answer**: *We're trying to generate an image from a text embedding!*

Unfortunately, we run into the issue mentioned previously: the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution.
```python
# Load models
prior = Prior(checkpoint="prior.pth") # A prior trained to go from: text -> clip text emb -> clip img emb
decoder = Decoder(checkpoint="decoder.pth") # A decoder trained on CLIP image embeddings

# Retrieve a prompt from the user and encode it with the prior
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
image_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!

# Now, pass the predicted image embedding to the decoder
predicted_image = decoder.sample(image_embedding)
```

With the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason, the decoder will perform much better, as it receives input that is much closer to its training data.

> **You may be asking yourself the following question:**
>
> *"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"*
>
> OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *"it doesn't work as well as decoders trained on image embeddings"*...also, it's just an example :smile:
## Usage

To utilize a pre-trained prior, it's quite simple.

### Loading Checkpoints
```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

def load_diffusion_model(dprior_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4
    )

    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True
    )

    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer
```
Here we instantiate a model that matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*).
### Sampling

Once we have a pre-trained model, generating embeddings is quite simple!
```python
# tokenize the text
tokenized_text = clip.tokenize("<your amazing prompt>")
# predict an embedding
predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
```

The resulting tensor returned from `.sample()` has the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape (1, 768).

> For CLIP priors, this is quite handy as it means that you can use `prior.sample(tokenized_text)` as a drop-in replacement for `clip.encode_text()`.

**Some things to note:**
* It is possible to specify the number of embeddings to sample from (the default suggested by OpenAI is `n=2`). Put simply, the idea here is that you avoid getting unlucky with a bad embedding generation by creating two and selecting the one with the higher cosine similarity to the prompt (see the sketch below).
* You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0`, but *ymmv*.
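To make the reranking idea concrete, here is a minimal sketch of what `n_samples_per_batch` approximates. It assumes `prior`, `clip_model`, and `tokenize` are already loaded as in the earlier examples; the real logic lives inside the prior's `.sample()` method, so treat this purely as an illustration.

```python
import torch.nn.functional as F

tokenized_text = tokenize("A corgi wearing sunglasses")
text_embedding = clip_model.encode_text(tokenized_text)

# draw two candidate image embeddings for the same prompt
candidates = [prior.sample(tokenized_text, n_samples_per_batch=1) for _ in range(2)]

# keep the candidate whose cosine similarity to the prompt's text embedding is higher
scores = [F.cosine_similarity(c, text_embedding, dim=-1) for c in candidates]
best_embedding = candidates[0] if scores[0] >= scores[1] else candidates[1]
```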
---

## Training

### Overview

Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step required of you is preparing a dataset in the format that `EmbeddingReader` expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended, as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move on to configuration.
## Dataset

To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2dataset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) for generating the actual embeddings that can be used in the prior's dataloader.
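As a point of reference, the resulting embedding shards can be read back with the [embedding-reader](https://github.com/rom1504/embedding-reader) library. This is a minimal sketch; the folder path is an assumption for illustration.

```python
from embedding_reader import EmbeddingReader

# point the reader at a folder of .npy embedding shards (e.g. clip_retrieval output)
embedding_reader = EmbeddingReader(embeddings_folder="path/to/img_emb", file_format="npy")

# stream the embeddings in large batches
for embeddings, meta in embedding_reader(batch_size=10 ** 6, start=0, end=embedding_reader.count):
    print(embeddings.shape)
```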
## Configuration

The configuration file allows you to easily track and reproduce experiments. It is a simple JSON file that specifies the architecture, dataset, and training parameters. For more information and specifics please see the configuration README.
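For instance, a config can be loaded and turned into a model in a few lines. This is a sketch assuming the `TrainDiffusionPriorConfig` pydantic class and its `from_json_path` helper in `dalle2_pytorch.train_configs`, as used by the training scripts; the config path is illustrative.

```python
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig

# parse the JSON config and instantiate the DiffusionPrior it describes
config = TrainDiffusionPriorConfig.from_json_path("configs/prior_config.json")
diffusion_prior = config.prior.create()
```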
## Distributed Training

If you would like to train in a distributed manner, we have opted to leverage Hugging Face's Accelerate library. Accelerate makes it extremely simple to distribute work across multiple GPUs and nodes. All that is required of you is to run its simple CLI configuration tool ([more information here](https://huggingface.co/docs/accelerate/accelerator)).
## Evaluation

There are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below:

| Metric | Description | Comments |
| --- | --- | --- |
| Online Model Validation | The validation loss associated with your online model. | Ideally validation loss will be as low as possible. Using L2 loss, values as low as `0.1` and lower are possible after around 1 billion samples seen. |
| EMA Validation | This metric measures the validation loss associated with your EMA model. | This will likely lag behind your "online" model's validation loss, but should outperform it in the long term. |
| Baseline Similarity | Baseline similarity refers to the similarity between your dataset's prompts and associated image embeddings. This will serve as a guide for your prior's performance in cosine similarity. | Generally `0.3` is considered a good cosine similarity for caption similarity. |
| Similarity With Original Image | This metric measures the cosine similarity between your prior's predicted image embedding and the actual image that the caption was associated with. This is useful for determining whether your prior is generating images with the right contents. | Values around `0.75`+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above `0.5`/`0.6` similarity, then you are likely suffering from some kind of training error or inefficiency (i.e. not using EMA). |
| Difference From Baseline Similarity | Sometimes it's useful to visualize a metric in another light. This metric shows you how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around `0.0`, with some room for variation. After a billion samples seen, values stay within about ±`0.01` of `0.0`. If this climbs too high (roughly above `0.02`), it may be a sign that your model is overfitting somehow. |
| Similarity With Text | This metric is your bread-and-butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be one of your main focuses; it is probably the second most important metric behind your loss. | As mentioned, this value should be close to baseline similarity. We have observed an early rapid increase with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity, it could be an indication that your model is overfitting. |
| Similarity With Unrelated Caption | This metric attempts to expose an overfit prior by feeding it arbitrary prompts (from your dataset) and then measuring the similarity of the predicted embedding with some other image. | Early on we found that a poorly trained/modeled prior could effectively fool CLIP into believing that the cosine similarity between two images was high (when in fact the caption and image were completely unrelated). With this in mind, a low value is ideal; anything below `0.1` is probably safe. |
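Most of these metrics reduce to a cosine similarity between two embedding tensors. As a hedged sketch, "Similarity With Original Image" can be computed along these lines; the tensor names are illustrative and assume a batch of prompt tokens alongside their ground-truth CLIP image embeddings.

```python
import torch.nn.functional as F

# predicted_image_embed: (b, 768); real_image_embed: (b, 768) ground truth from the dataset
predicted_image_embed = prior.sample(tokenized_text)
similarity = F.cosine_similarity(predicted_image_embed, real_image_embed, dim=-1)
print(f"Similarity With Original Image: {similarity.mean().item():.3f}")
```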
## Launching the script

Now that you've done all the prep, it's time for the easy part! 🚀

To actually launch the script, run `accelerate launch train_diffusion_prior.py --config_path <path to your config>` to train in a distributed fashion with Hugging Face Accelerate, or `python train_diffusion_prior.py` if you would like to train on your GPU/CPU without it.
## Checkpointing

Checkpoints will be saved to the directory specified in your configuration file.

Additionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and titled `latest.pth`. This is to avoid problems where your `save_every` configuration does not overlap with the number of steps required to do a complete pass through the data.
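Resuming from that final checkpoint then reduces to loading it with the same helper shown earlier; the save directory here is an assumption.

```python
# reuse the load_diffusion_model helper defined in "Loading Checkpoints"
trainer = load_diffusion_model("<your_save_directory>/latest.pth")
```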
## Things To Keep In Mind

The prior has not been trained for tasks other than the traditional CLIP embedding translation...at least not yet.

As we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.

With that in mind, you are more or less a pioneer in embedding translation if you are reading this and attempting something you don't see documentation for!
1 setup.py
@@ -37,6 +37,7 @@ setup(
    'packaging',
    'pillow',
    'pydantic',
    'pytorch-warmup',
    'resize-right>=0.0.2',
    'rotary-embedding-torch',
    'torch>=1.10',
106 train_decoder.py
@@ -1,11 +1,12 @@
from pathlib import Path
from typing import List

from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.trackers import Tracker
from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
from clip import tokenize

import torchvision
@@ -239,42 +240,33 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
            metrics[metric_name] = metrics_tensor[i].item()
    return metrics

def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths):
def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
    """
    Logs the model with an appropriate method depending on the tracker
    """
    if isinstance(relative_paths, str):
        relative_paths = [relative_paths]
    for relative_path in relative_paths:
        local_path = str(tracker.data_path / relative_path)
        trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses)
        tracker.save_file(local_path)
    tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)

def recall_trainer(tracker, trainer, recall_source=None, **load_config):
def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
    """
    Loads the model with an appropriate method depending on the tracker
    """
    trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}"))
    local_filepath = tracker.recall_file(recall_source, **load_config)
    state_dict = trainer.load(local_filepath)
    return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0)
    trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
    state_dict = tracker.recall()
    trainer.load_state_dict(state_dict, only_model=False, strict=True)
    return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)
def train(
    dataloaders,
    decoder,
    accelerator,
    tracker,
    decoder: Decoder,
    accelerator: Accelerator,
    tracker: Tracker,
    inference_device,
    load_config=None,
    evaluate_config=None,
    epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
    validation_samples = None,
    epochs = 20,
    n_sample_images = 5,
    save_every_n_samples = 100000,
    save_all=False,
    save_latest=True,
    save_best=True,
    unet_training_mask=None,
    condition_on_text_encodings=False,
    **kwargs
@@ -299,13 +291,13 @@ def train(
    val_sample = 0
    step = lambda: int(trainer.step.item())

    if exists(load_config) and exists(load_config.source):
        start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
    if tracker.loader is not None:
        start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
        if next_task == 'train':
            sample = recalled_sample
        if next_task == 'val':
            val_sample = recalled_sample
        accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
        accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
        accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
    trainer.to(device=inference_device)
@@ -399,19 +391,14 @@ def train(
        }

        if is_master:
            tracker.log(log_data, step=step(), verbose=True)
            tracker.log(log_data, step=step())

        if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
            # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
            print("Saving snapshot")
            last_snapshot = sample
            # We need to know where the model should be saved
            save_paths = []
            if save_latest:
                save_paths.append("latest.pth")
            if save_all:
                save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth")
            save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
            save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
            if exists(n_sample_images) and n_sample_images > 0:
                trainer.eval()
                train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
@@ -486,7 +473,7 @@ def train(
        if is_master:
            unet_average_val_loss = all_average_val_losses.mean(dim=0)
            val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
            tracker.log(val_loss_map, step=step(), verbose=True)
            tracker.log(val_loss_map, step=step())
        next_task = 'eval'

    if next_task == 'eval':
@@ -494,7 +481,7 @@ def train(
        accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
        evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
        if is_master:
            tracker.log(evaluation, step=step(), verbose=True)
            tracker.log(evaluation, step=step())
        next_task = 'sample'
        val_sample = 0
@@ -509,22 +496,16 @@ def train(
        tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

        print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
        # Get the same paths
        save_paths = []
        if save_latest:
            save_paths.append("latest.pth")
        is_best = False
        if all_average_val_losses is not None:
            average_loss = all_average_val_losses.mean(dim=0).item()
            if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
                save_paths.append("best.pth")
            if len(validation_losses) == 0 or average_loss < min(validation_losses):
                is_best = True
            validation_losses.append(average_loss)
        save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
        save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
        next_task = 'train'

def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None):
    """
    Creates a tracker of the specified type and initializes special features based on the full config
    """
def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
    tracker_config = config.tracker
    accelerator_config = {
        "Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
@@ -532,40 +513,16 @@ def create_tracker(accelerator, config, config_path, tracker_type=None, data_pat
        "NumProcesses": accelerator.num_processes,
        "MixedPrecision": accelerator.mixed_precision
    }
    init_config = { "config": {**config.dict(), **accelerator_config} }
    data_path = data_path or tracker_config.data_path
    tracker_type = tracker_type or tracker_config.tracker_type

    if tracker_type == "dummy":
        tracker = DummyTracker(data_path)
        tracker.init(**init_config)
    elif tracker_type == "console":
        tracker = ConsoleTracker(data_path)
        tracker.init(**init_config)
    elif tracker_type == "wandb":
        # We need to initialize the resume state here
        load_config = config.load
        if load_config.source == "wandb" and load_config.resume:
            # Then we are resuming the run load_config["run_path"]
            run_id = load_config.run_path.split("/")[-1]
            init_config["id"] = run_id
            init_config["resume"] = "must"

        init_config["entity"] = tracker_config.wandb_entity
        init_config["project"] = tracker_config.wandb_project
        tracker = WandbTracker(data_path)
        tracker.init(**init_config)
        tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute()))
    else:
        raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
    tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
    tracker.save_config(config_path, config_name='decoder_config.json')
    return tracker

def initialize_training(config, config_path):
def initialize_training(config: TrainDecoderConfig, config_path):
    # Make sure if we are not loading, distributed models are initialized to the same values
    torch.manual_seed(config.seed)

    # Set up accelerator for configurable distributed training
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

    # Set up data
@@ -592,7 +549,7 @@ def initialize_training(config, config_path):
    num_parameters = sum(p.numel() for p in decoder.parameters())

    # Create and initialize the tracker if we are the master
    tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
    tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)

    has_img_embeddings = config.data.img_embeddings_url is not None
    has_text_embeddings = config.data.text_embeddings_url is not None
@@ -622,7 +579,6 @@ def initialize_training(config, config_path):
    train(dataloaders, decoder, accelerator,
        tracker=tracker,
        inference_device=accelerator.device,
        load_config=config.load,
        evaluate_config=config.evaluate,
        condition_on_text_encodings=conditioning_on_text,
        **config.train.dict(),