mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 08:14:21 +01:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80046334ad | ||
|
|
36fb46a95e | ||
|
|
07abfcf45b | ||
|
|
2e35a9967d | ||
|
|
406e75043f | ||
|
|
9646dfc0e6 | ||
|
|
62043acb2f | ||
|
|
417ff808e6 | ||
|
|
f3d7e226ba | ||
|
|
48a1302428 | ||
|
|
ccaa46b81b | ||
|
|
76d08498cc | ||
|
|
f9423d308b |
18
README.md
18
README.md
@@ -371,6 +371,7 @@ loss.backward()
|
|||||||
unet1 = Unet(
|
unet1 = Unet(
|
||||||
dim = 128,
|
dim = 128,
|
||||||
image_embed_dim = 512,
|
image_embed_dim = 512,
|
||||||
|
text_embed_dim = 512,
|
||||||
cond_dim = 128,
|
cond_dim = 128,
|
||||||
channels = 3,
|
channels = 3,
|
||||||
dim_mults=(1, 2, 4, 8),
|
dim_mults=(1, 2, 4, 8),
|
||||||
@@ -395,7 +396,7 @@ decoder = Decoder(
|
|||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
for unet_number in (1, 2):
|
for unet_number in (1, 2):
|
||||||
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# do above for many steps
|
# do above for many steps
|
||||||
@@ -860,25 +861,23 @@ unet1 = Unet(
|
|||||||
text_embed_dim = 512,
|
text_embed_dim = 512,
|
||||||
cond_dim = 128,
|
cond_dim = 128,
|
||||||
channels = 3,
|
channels = 3,
|
||||||
dim_mults=(1, 2, 4, 8)
|
dim_mults=(1, 2, 4, 8),
|
||||||
|
cond_on_text_encodings = True,
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
unet2 = Unet(
|
unet2 = Unet(
|
||||||
dim = 16,
|
dim = 16,
|
||||||
image_embed_dim = 512,
|
image_embed_dim = 512,
|
||||||
text_embed_dim = 512,
|
|
||||||
cond_dim = 128,
|
cond_dim = 128,
|
||||||
channels = 3,
|
channels = 3,
|
||||||
dim_mults = (1, 2, 4, 8, 16),
|
dim_mults = (1, 2, 4, 8, 16),
|
||||||
cond_on_text_encodings = True
|
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
unet = (unet1, unet2),
|
unet = (unet1, unet2),
|
||||||
image_sizes = (128, 256),
|
image_sizes = (128, 256),
|
||||||
clip = clip,
|
clip = clip,
|
||||||
timesteps = 1000,
|
timesteps = 1000
|
||||||
condition_on_text_encodings = True
|
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
decoder_trainer = DecoderTrainer(
|
decoder_trainer = DecoderTrainer(
|
||||||
@@ -903,8 +902,8 @@ for unet_number in (1, 2):
|
|||||||
# after much training
|
# after much training
|
||||||
# you can sample from the exponentially moving averaged unets as so
|
# you can sample from the exponentially moving averaged unets as so
|
||||||
|
|
||||||
mock_image_embed = torch.randn(4, 512).cuda()
|
mock_image_embed = torch.randn(32, 512).cuda()
|
||||||
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Diffusion Prior Training
|
### Diffusion Prior Training
|
||||||
@@ -1112,7 +1111,8 @@ For detailed information on training the diffusion prior, please refer to the [d
|
|||||||
- [x] allow for unet to be able to condition non-cross attention style as well
|
- [x] allow for unet to be able to condition non-cross attention style as well
|
||||||
- [x] speed up inference, read up on papers (ddim)
|
- [x] speed up inference, read up on papers (ddim)
|
||||||
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||||
- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
|
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
|
||||||
|
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
|
||||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||||
|
|
||||||
## Citations
|
## Citations
|
||||||
|
|||||||
@@ -1,18 +1,14 @@
|
|||||||
{
|
{
|
||||||
"prior": {
|
"prior": {
|
||||||
"clip": {
|
"clip": {
|
||||||
"make": "x-clip",
|
"make": "openai",
|
||||||
"model": "ViT-L/14",
|
"model": "ViT-L/14"
|
||||||
"base_model_kwargs": {
|
|
||||||
"dim_text": 768,
|
|
||||||
"dim_image": 768,
|
|
||||||
"dim_latent": 768
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"net": {
|
"net": {
|
||||||
"dim": 768,
|
"dim": 768,
|
||||||
"depth": 12,
|
"depth": 12,
|
||||||
"num_timesteps": 1000,
|
"num_timesteps": 1000,
|
||||||
|
"max_text_len": 77,
|
||||||
"num_time_embeds": 1,
|
"num_time_embeds": 1,
|
||||||
"num_image_embeds": 1,
|
"num_image_embeds": 1,
|
||||||
"num_text_embeds": 1,
|
"num_text_embeds": 1,
|
||||||
@@ -20,8 +16,8 @@
|
|||||||
"heads": 12,
|
"heads": 12,
|
||||||
"ff_mult": 4,
|
"ff_mult": 4,
|
||||||
"norm_out": true,
|
"norm_out": true,
|
||||||
"attn_dropout": 0.0,
|
"attn_dropout": 0.05,
|
||||||
"ff_dropout": 0.0,
|
"ff_dropout": 0.05,
|
||||||
"final_proj": true,
|
"final_proj": true,
|
||||||
"normformer": true,
|
"normformer": true,
|
||||||
"rotary_emb": true
|
"rotary_emb": true
|
||||||
@@ -30,6 +26,7 @@
|
|||||||
"image_size": 224,
|
"image_size": 224,
|
||||||
"image_channels": 3,
|
"image_channels": 3,
|
||||||
"timesteps": 1000,
|
"timesteps": 1000,
|
||||||
|
"sample_timesteps": 64,
|
||||||
"cond_drop_prob": 0.1,
|
"cond_drop_prob": 0.1,
|
||||||
"loss_type": "l2",
|
"loss_type": "l2",
|
||||||
"predict_x_start": true,
|
"predict_x_start": true,
|
||||||
@@ -37,34 +34,48 @@
|
|||||||
"condition_on_text_encodings": true
|
"condition_on_text_encodings": true
|
||||||
},
|
},
|
||||||
"data": {
|
"data": {
|
||||||
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
|
"batch_size": 128,
|
||||||
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
|
"num_data_points": 100000,
|
||||||
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
|
"eval_every_seconds": 1600,
|
||||||
"batch_size": 256,
|
"image_url": "<path to your images>",
|
||||||
|
"meta_url": "<path to your metadata>",
|
||||||
"splits": {
|
"splits": {
|
||||||
"train": 0.9,
|
"train": 0.8,
|
||||||
"val": 1e-7,
|
"val": 0.1,
|
||||||
"test": 0.0999999
|
"test": 0.1
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"train": {
|
"train": {
|
||||||
"epochs": 1,
|
"epochs": 5,
|
||||||
"lr": 1.1e-4,
|
"lr": 1.1e-4,
|
||||||
"wd": 6.02e-2,
|
"wd": 6.02e-2,
|
||||||
"max_grad_norm": 0.5,
|
"max_grad_norm": 0.5,
|
||||||
"use_ema": true,
|
"use_ema": true,
|
||||||
|
"ema_beta": 0.9999,
|
||||||
|
"ema_update_after_step": 50,
|
||||||
|
"warmup_steps": 50,
|
||||||
"amp": false,
|
"amp": false,
|
||||||
"save_every": 10000
|
"save_every_seconds": 3600,
|
||||||
},
|
"eval_timesteps": [64, 1000],
|
||||||
"load": {
|
"random_seed": 84513
|
||||||
"source": null,
|
|
||||||
"resume": false
|
|
||||||
},
|
},
|
||||||
"tracker": {
|
"tracker": {
|
||||||
"tracker_type": "wandb",
|
"data_path": ".prior",
|
||||||
"data_path": "./prior_checkpoints",
|
"overwrite_data_path": true,
|
||||||
"wandb_entity": "laion",
|
"log": {
|
||||||
"wandb_project": "diffusion-prior",
|
"log_type": "wandb",
|
||||||
"verbose": true
|
"wandb_entity": "<your wandb username>",
|
||||||
|
"wandb_project": "prior_debugging",
|
||||||
|
"wandb_resume": false,
|
||||||
|
"verbose": true
|
||||||
|
},
|
||||||
|
"save": [
|
||||||
|
{
|
||||||
|
"save_to": "local",
|
||||||
|
"save_type": "checkpoint",
|
||||||
|
"save_latest_to": ".prior/latest_checkpoint.pth",
|
||||||
|
"save_best_to": ".prior/best_checkpoint.pth"
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -516,6 +516,17 @@ class NoiseScheduler(nn.Module):
|
|||||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
|
||||||
|
shape = x_from.shape
|
||||||
|
noise = default(noise, lambda: torch.randn_like(x_from))
|
||||||
|
|
||||||
|
alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)
|
||||||
|
sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)
|
||||||
|
alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)
|
||||||
|
sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)
|
||||||
|
|
||||||
|
return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
|
||||||
|
|
||||||
def predict_start_from_noise(self, x_t, t, noise):
|
def predict_start_from_noise(self, x_t, t, noise):
|
||||||
return (
|
return (
|
||||||
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
||||||
@@ -1492,6 +1503,7 @@ class LinearAttention(nn.Module):
|
|||||||
k = k.softmax(dim = -2)
|
k = k.softmax(dim = -2)
|
||||||
|
|
||||||
q = q * self.scale
|
q = q * self.scale
|
||||||
|
v = v / (x * y)
|
||||||
|
|
||||||
context = einsum('b n d, b n e -> b d e', k, v)
|
context = einsum('b n d, b n e -> b d e', k, v)
|
||||||
out = einsum('b n d, b d e -> b n e', q, context)
|
out = einsum('b n d, b d e -> b n e', q, context)
|
||||||
@@ -1527,6 +1539,38 @@ class CrossEmbedLayer(nn.Module):
|
|||||||
fmaps = tuple(map(lambda conv: conv(x), self.convs))
|
fmaps = tuple(map(lambda conv: conv(x), self.convs))
|
||||||
return torch.cat(fmaps, dim = 1)
|
return torch.cat(fmaps, dim = 1)
|
||||||
|
|
||||||
|
class UpsampleCombiner(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
*,
|
||||||
|
enabled = False,
|
||||||
|
dim_ins = tuple(),
|
||||||
|
dim_outs = tuple()
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert len(dim_ins) == len(dim_outs)
|
||||||
|
self.enabled = enabled
|
||||||
|
|
||||||
|
if not self.enabled:
|
||||||
|
self.dim_out = dim
|
||||||
|
return
|
||||||
|
|
||||||
|
self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
|
||||||
|
self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
|
||||||
|
|
||||||
|
def forward(self, x, fmaps = None):
|
||||||
|
target_size = x.shape[-1]
|
||||||
|
|
||||||
|
fmaps = default(fmaps, tuple())
|
||||||
|
|
||||||
|
if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
|
||||||
|
return x
|
||||||
|
|
||||||
|
fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
|
||||||
|
outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
|
||||||
|
return torch.cat((x, *outs), dim = 1)
|
||||||
|
|
||||||
class Unet(nn.Module):
|
class Unet(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -1564,6 +1608,7 @@ class Unet(nn.Module):
|
|||||||
scale_skip_connection = False,
|
scale_skip_connection = False,
|
||||||
pixel_shuffle_upsample = True,
|
pixel_shuffle_upsample = True,
|
||||||
final_conv_kernel_size = 1,
|
final_conv_kernel_size = 1,
|
||||||
|
combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1699,7 +1744,8 @@ class Unet(nn.Module):
|
|||||||
self.ups = nn.ModuleList([])
|
self.ups = nn.ModuleList([])
|
||||||
num_resolutions = len(in_out)
|
num_resolutions = len(in_out)
|
||||||
|
|
||||||
skip_connect_dims = [] # keeping track of skip connection dimensions
|
skip_connect_dims = [] # keeping track of skip connection dimensions
|
||||||
|
upsample_combiner_dims = [] # keeping track of dimensions for final upsample feature map combiner
|
||||||
|
|
||||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
|
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
|
||||||
is_first = ind == 0
|
is_first = ind == 0
|
||||||
@@ -1741,6 +1787,8 @@ class Unet(nn.Module):
|
|||||||
elif sparse_attn:
|
elif sparse_attn:
|
||||||
attention = Residual(LinearAttention(dim_out, **attn_kwargs))
|
attention = Residual(LinearAttention(dim_out, **attn_kwargs))
|
||||||
|
|
||||||
|
upsample_combiner_dims.append(dim_out)
|
||||||
|
|
||||||
self.ups.append(nn.ModuleList([
|
self.ups.append(nn.ModuleList([
|
||||||
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||||
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||||
@@ -1748,7 +1796,18 @@ class Unet(nn.Module):
|
|||||||
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
|
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
|
||||||
]))
|
]))
|
||||||
|
|
||||||
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
|
# whether to combine outputs from all upsample blocks for final resnet block
|
||||||
|
|
||||||
|
self.upsample_combiner = UpsampleCombiner(
|
||||||
|
dim = dim,
|
||||||
|
enabled = combine_upsample_fmaps,
|
||||||
|
dim_ins = upsample_combiner_dims,
|
||||||
|
dim_outs = (dim,) * len(upsample_combiner_dims)
|
||||||
|
)
|
||||||
|
|
||||||
|
# a final resnet block
|
||||||
|
|
||||||
|
self.final_resnet_block = ResnetBlock(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
|
||||||
|
|
||||||
out_dim_in = dim + (channels if lowres_cond else 0)
|
out_dim_in = dim + (channels if lowres_cond else 0)
|
||||||
|
|
||||||
@@ -1772,7 +1831,7 @@ class Unet(nn.Module):
|
|||||||
channels == self.channels and \
|
channels == self.channels and \
|
||||||
cond_on_image_embeds == self.cond_on_image_embeds and \
|
cond_on_image_embeds == self.cond_on_image_embeds and \
|
||||||
cond_on_text_encodings == self.cond_on_text_encodings and \
|
cond_on_text_encodings == self.cond_on_text_encodings and \
|
||||||
cond_on_lowres_noise == self.cond_on_lowres_noise and \
|
lowres_noise_cond == self.lowres_noise_cond and \
|
||||||
channels_out == self.channels_out:
|
channels_out == self.channels_out:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -1942,7 +2001,8 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
# go through the layers of the unet, down and up
|
# go through the layers of the unet, down and up
|
||||||
|
|
||||||
hiddens = []
|
down_hiddens = []
|
||||||
|
up_hiddens = []
|
||||||
|
|
||||||
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
|
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
|
||||||
if exists(pre_downsample):
|
if exists(pre_downsample):
|
||||||
@@ -1952,10 +2012,10 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
for resnet_block in resnet_blocks:
|
for resnet_block in resnet_blocks:
|
||||||
x = resnet_block(x, t, c)
|
x = resnet_block(x, t, c)
|
||||||
hiddens.append(x)
|
down_hiddens.append(x.contiguous())
|
||||||
|
|
||||||
x = attn(x)
|
x = attn(x)
|
||||||
hiddens.append(x.contiguous())
|
down_hiddens.append(x.contiguous())
|
||||||
|
|
||||||
if exists(post_downsample):
|
if exists(post_downsample):
|
||||||
x = post_downsample(x)
|
x = post_downsample(x)
|
||||||
@@ -1967,7 +2027,7 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
x = self.mid_block2(x, t, mid_c)
|
x = self.mid_block2(x, t, mid_c)
|
||||||
|
|
||||||
connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
|
connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
|
||||||
|
|
||||||
for init_block, resnet_blocks, attn, upsample in self.ups:
|
for init_block, resnet_blocks, attn, upsample in self.ups:
|
||||||
x = connect_skip(x)
|
x = connect_skip(x)
|
||||||
@@ -1978,8 +2038,12 @@ class Unet(nn.Module):
|
|||||||
x = resnet_block(x, t, c)
|
x = resnet_block(x, t, c)
|
||||||
|
|
||||||
x = attn(x)
|
x = attn(x)
|
||||||
|
|
||||||
|
up_hiddens.append(x.contiguous())
|
||||||
x = upsample(x)
|
x = upsample(x)
|
||||||
|
|
||||||
|
x = self.upsample_combiner(x, up_hiddens)
|
||||||
|
|
||||||
x = torch.cat((x, r), dim = 1)
|
x = torch.cat((x, r), dim = 1)
|
||||||
|
|
||||||
x = self.final_resnet_block(x, t)
|
x = self.final_resnet_block(x, t)
|
||||||
@@ -2432,14 +2496,18 @@ class Decoder(nn.Module):
|
|||||||
is_latent_diffusion = False,
|
is_latent_diffusion = False,
|
||||||
lowres_noise_level = None,
|
lowres_noise_level = None,
|
||||||
inpaint_image = None,
|
inpaint_image = None,
|
||||||
inpaint_mask = None
|
inpaint_mask = None,
|
||||||
|
inpaint_resample_times = 5
|
||||||
):
|
):
|
||||||
device = self.device
|
device = self.device
|
||||||
|
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
img = torch.randn(shape, device = device)
|
img = torch.randn(shape, device = device)
|
||||||
|
|
||||||
if exists(inpaint_image):
|
is_inpaint = exists(inpaint_image)
|
||||||
|
resample_times = inpaint_resample_times if is_inpaint else 1
|
||||||
|
|
||||||
|
if is_inpaint:
|
||||||
inpaint_image = self.normalize_img(inpaint_image)
|
inpaint_image = self.normalize_img(inpaint_image)
|
||||||
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
|
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
|
||||||
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
|
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
|
||||||
@@ -2449,31 +2517,40 @@ class Decoder(nn.Module):
|
|||||||
if not is_latent_diffusion:
|
if not is_latent_diffusion:
|
||||||
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||||
|
|
||||||
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
|
for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
|
||||||
times = torch.full((b,), i, device = device, dtype = torch.long)
|
is_last_timestep = time == 0
|
||||||
|
|
||||||
if exists(inpaint_image):
|
for r in reversed(range(0, resample_times)):
|
||||||
# following the repaint paper
|
is_last_resample_step = r == 0
|
||||||
# https://arxiv.org/abs/2201.09865
|
|
||||||
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
|
|
||||||
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
|
||||||
|
|
||||||
img = self.p_sample(
|
times = torch.full((b,), time, device = device, dtype = torch.long)
|
||||||
unet,
|
|
||||||
img,
|
|
||||||
times,
|
|
||||||
image_embed = image_embed,
|
|
||||||
text_encodings = text_encodings,
|
|
||||||
cond_scale = cond_scale,
|
|
||||||
lowres_cond_img = lowres_cond_img,
|
|
||||||
lowres_noise_level = lowres_noise_level,
|
|
||||||
predict_x_start = predict_x_start,
|
|
||||||
noise_scheduler = noise_scheduler,
|
|
||||||
learned_variance = learned_variance,
|
|
||||||
clip_denoised = clip_denoised
|
|
||||||
)
|
|
||||||
|
|
||||||
if exists(inpaint_image):
|
if is_inpaint:
|
||||||
|
# following the repaint paper
|
||||||
|
# https://arxiv.org/abs/2201.09865
|
||||||
|
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
|
||||||
|
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
||||||
|
|
||||||
|
img = self.p_sample(
|
||||||
|
unet,
|
||||||
|
img,
|
||||||
|
times,
|
||||||
|
image_embed = image_embed,
|
||||||
|
text_encodings = text_encodings,
|
||||||
|
cond_scale = cond_scale,
|
||||||
|
lowres_cond_img = lowres_cond_img,
|
||||||
|
lowres_noise_level = lowres_noise_level,
|
||||||
|
predict_x_start = predict_x_start,
|
||||||
|
noise_scheduler = noise_scheduler,
|
||||||
|
learned_variance = learned_variance,
|
||||||
|
clip_denoised = clip_denoised
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_inpaint and not (is_last_timestep or is_last_resample_step):
|
||||||
|
# in repaint, you renoise and resample up to 10 times every step
|
||||||
|
img = noise_scheduler.q_sample_from_to(img, times - 1, times)
|
||||||
|
|
||||||
|
if is_inpaint:
|
||||||
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
||||||
|
|
||||||
unnormalize_img = self.unnormalize_img(img)
|
unnormalize_img = self.unnormalize_img(img)
|
||||||
@@ -2497,7 +2574,8 @@ class Decoder(nn.Module):
|
|||||||
is_latent_diffusion = False,
|
is_latent_diffusion = False,
|
||||||
lowres_noise_level = None,
|
lowres_noise_level = None,
|
||||||
inpaint_image = None,
|
inpaint_image = None,
|
||||||
inpaint_mask = None
|
inpaint_mask = None,
|
||||||
|
inpaint_resample_times = 5
|
||||||
):
|
):
|
||||||
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
|
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
|
||||||
|
|
||||||
@@ -2506,7 +2584,10 @@ class Decoder(nn.Module):
|
|||||||
times = list(reversed(times.int().tolist()))
|
times = list(reversed(times.int().tolist()))
|
||||||
time_pairs = list(zip(times[:-1], times[1:]))
|
time_pairs = list(zip(times[:-1], times[1:]))
|
||||||
|
|
||||||
if exists(inpaint_image):
|
is_inpaint = exists(inpaint_image)
|
||||||
|
resample_times = inpaint_resample_times if is_inpaint else 1
|
||||||
|
|
||||||
|
if is_inpaint:
|
||||||
inpaint_image = self.normalize_img(inpaint_image)
|
inpaint_image = self.normalize_img(inpaint_image)
|
||||||
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
|
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
|
||||||
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
|
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
|
||||||
@@ -2519,39 +2600,49 @@ class Decoder(nn.Module):
|
|||||||
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||||
|
|
||||||
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
|
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
|
||||||
alpha = alphas[time]
|
is_last_timestep = time_next == 0
|
||||||
alpha_next = alphas[time_next]
|
|
||||||
|
|
||||||
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
for r in reversed(range(0, resample_times)):
|
||||||
|
is_last_resample_step = r == 0
|
||||||
|
|
||||||
if exists(inpaint_image):
|
alpha = alphas[time]
|
||||||
# following the repaint paper
|
alpha_next = alphas[time_next]
|
||||||
# https://arxiv.org/abs/2201.09865
|
|
||||||
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
|
|
||||||
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
|
||||||
|
|
||||||
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
||||||
|
|
||||||
if learned_variance:
|
if is_inpaint:
|
||||||
pred, _ = pred.chunk(2, dim = 1)
|
# following the repaint paper
|
||||||
|
# https://arxiv.org/abs/2201.09865
|
||||||
|
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
|
||||||
|
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
||||||
|
|
||||||
if predict_x_start:
|
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
||||||
x_start = pred
|
|
||||||
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
|
|
||||||
else:
|
|
||||||
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
|
|
||||||
pred_noise = pred
|
|
||||||
|
|
||||||
if clip_denoised:
|
if learned_variance:
|
||||||
x_start = self.dynamic_threshold(x_start)
|
pred, _ = pred.chunk(2, dim = 1)
|
||||||
|
|
||||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
if predict_x_start:
|
||||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
x_start = pred
|
||||||
noise = torch.randn_like(img) if time_next > 0 else 0.
|
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
|
||||||
|
else:
|
||||||
|
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
|
||||||
|
pred_noise = pred
|
||||||
|
|
||||||
img = x_start * alpha_next.sqrt() + \
|
if clip_denoised:
|
||||||
c1 * noise + \
|
x_start = self.dynamic_threshold(x_start)
|
||||||
c2 * pred_noise
|
|
||||||
|
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||||
|
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||||
|
noise = torch.randn_like(img) if not is_last_timestep else 0.
|
||||||
|
|
||||||
|
img = x_start * alpha_next.sqrt() + \
|
||||||
|
c1 * noise + \
|
||||||
|
c2 * pred_noise
|
||||||
|
|
||||||
|
if is_inpaint and not (is_last_timestep or is_last_resample_step):
|
||||||
|
# in repaint, you renoise and resample up to 10 times every step
|
||||||
|
time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)
|
||||||
|
img = noise_scheduler.q_sample_from_to(img, time_next_cond, time_cond)
|
||||||
|
|
||||||
if exists(inpaint_image):
|
if exists(inpaint_image):
|
||||||
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
||||||
@@ -2658,7 +2749,8 @@ class Decoder(nn.Module):
|
|||||||
stop_at_unet_number = None,
|
stop_at_unet_number = None,
|
||||||
distributed = False,
|
distributed = False,
|
||||||
inpaint_image = None,
|
inpaint_image = None,
|
||||||
inpaint_mask = None
|
inpaint_mask = None,
|
||||||
|
inpaint_resample_times = 5
|
||||||
):
|
):
|
||||||
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
|
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
|
||||||
|
|
||||||
@@ -2730,7 +2822,8 @@ class Decoder(nn.Module):
|
|||||||
noise_scheduler = noise_scheduler,
|
noise_scheduler = noise_scheduler,
|
||||||
timesteps = sample_timesteps,
|
timesteps = sample_timesteps,
|
||||||
inpaint_image = inpaint_image,
|
inpaint_image = inpaint_image,
|
||||||
inpaint_mask = inpaint_mask
|
inpaint_mask = inpaint_mask,
|
||||||
|
inpaint_resample_times = inpaint_resample_times
|
||||||
)
|
)
|
||||||
|
|
||||||
img = vae.decode(img)
|
img = vae.decode(img)
|
||||||
@@ -2845,7 +2938,7 @@ class DALLE2(nn.Module):
|
|||||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
|
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
|
||||||
|
|
||||||
text_cond = text if self.decoder_need_text_cond else None
|
text_cond = text if self.decoder_need_text_cond else None
|
||||||
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
|
images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)
|
||||||
|
|
||||||
if return_pil_images:
|
if return_pil_images:
|
||||||
images = list(map(self.to_pil, images.unbind(dim = 0)))
|
images = list(map(self.to_pil, images.unbind(dim = 0)))
|
||||||
|
|||||||
@@ -67,6 +67,15 @@ class PriorEmbeddingDataset(IterableDataset):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
|
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
|
||||||
|
|
||||||
|
def set_start(self, start):
|
||||||
|
"""
|
||||||
|
Adjust the starting point within the reader, useful for resuming an epoch
|
||||||
|
"""
|
||||||
|
self.start = start
|
||||||
|
|
||||||
|
def get_start(self):
|
||||||
|
return self.start
|
||||||
|
|
||||||
def get_sample(self):
|
def get_sample(self):
|
||||||
"""
|
"""
|
||||||
pre-proocess data from either reader into a common format
|
pre-proocess data from either reader into a common format
|
||||||
|
|||||||
@@ -528,7 +528,7 @@ class Tracker:
|
|||||||
elif save_type == 'model':
|
elif save_type == 'model':
|
||||||
if isinstance(trainer, DiffusionPriorTrainer):
|
if isinstance(trainer, DiffusionPriorTrainer):
|
||||||
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
|
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
|
||||||
prior: DiffusionPrior = trainer.unwrap_model(prior)
|
prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)
|
||||||
# Remove CLIP if it is part of the model
|
# Remove CLIP if it is part of the model
|
||||||
original_clip = prior.clip
|
original_clip = prior.clip
|
||||||
prior.clip = None
|
prior.clip = None
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from torchvision import transforms as T
|
from torchvision import transforms as T
|
||||||
from pydantic import BaseModel, validator, root_validator
|
from pydantic import BaseModel, validator, root_validator
|
||||||
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
|
||||||
|
|
||||||
from x_clip import CLIP as XCLIP
|
from x_clip import CLIP as XCLIP
|
||||||
from coca_pytorch import CoCa
|
from coca_pytorch import CoCa
|
||||||
@@ -25,11 +25,9 @@ def exists(val):
|
|||||||
def default(val, d):
|
def default(val, d):
|
||||||
return val if exists(val) else d
|
return val if exists(val) else d
|
||||||
|
|
||||||
def ListOrTuple(inner_type):
|
InnerType = TypeVar('InnerType')
|
||||||
return Union[List[inner_type], Tuple[inner_type]]
|
ListOrTuple = Union[List[InnerType], Tuple[InnerType]]
|
||||||
|
SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]
|
||||||
def SingularOrIterable(inner_type):
|
|
||||||
return Union[inner_type, ListOrTuple(inner_type)]
|
|
||||||
|
|
||||||
# general pydantic classes
|
# general pydantic classes
|
||||||
|
|
||||||
@@ -145,6 +143,9 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
|||||||
normformer: bool = False
|
normformer: bool = False
|
||||||
rotary_emb: bool = True
|
rotary_emb: bool = True
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "allow"
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
kwargs = self.dict()
|
kwargs = self.dict()
|
||||||
return DiffusionPriorNetwork(**kwargs)
|
return DiffusionPriorNetwork(**kwargs)
|
||||||
@@ -187,23 +188,26 @@ class DiffusionPriorTrainConfig(BaseModel):
|
|||||||
use_ema: bool = True
|
use_ema: bool = True
|
||||||
ema_beta: float = 0.99
|
ema_beta: float = 0.99
|
||||||
amp: bool = False
|
amp: bool = False
|
||||||
save_every: int = 10000 # what steps to save on
|
warmup_steps: int = None # number of warmup steps
|
||||||
|
save_every_seconds: int = 3600 # how often to save
|
||||||
|
eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
|
||||||
|
best_validation_loss: float = 1e9 # the current best valudation loss observed
|
||||||
|
current_epoch: int = 0 # the current epoch
|
||||||
|
num_samples_seen: int = 0 # the current number of samples seen
|
||||||
|
random_seed: int = 0 # manual seed for torch
|
||||||
|
|
||||||
class DiffusionPriorDataConfig(BaseModel):
|
class DiffusionPriorDataConfig(BaseModel):
|
||||||
image_url: str # path to embeddings folder
|
image_url: str # path to embeddings folder
|
||||||
meta_url: str # path to metadata (captions) for images
|
meta_url: str # path to metadata (captions) for images
|
||||||
splits: TrainSplitConfig
|
splits: TrainSplitConfig # define train, validation, test splits for your dataset
|
||||||
batch_size: int = 64
|
batch_size: int # per-gpu batch size used to train the model
|
||||||
|
num_data_points: int = 25e7 # total number of datapoints to train on
|
||||||
class DiffusionPriorLoadConfig(BaseModel):
|
eval_every_seconds: int = 3600 # validation statistics will be performed this often
|
||||||
source: str = None
|
|
||||||
resume: bool = False
|
|
||||||
|
|
||||||
class TrainDiffusionPriorConfig(BaseModel):
|
class TrainDiffusionPriorConfig(BaseModel):
|
||||||
prior: DiffusionPriorConfig
|
prior: DiffusionPriorConfig
|
||||||
data: DiffusionPriorDataConfig
|
data: DiffusionPriorDataConfig
|
||||||
train: DiffusionPriorTrainConfig
|
train: DiffusionPriorTrainConfig
|
||||||
load: DiffusionPriorLoadConfig
|
|
||||||
tracker: TrackerConfig
|
tracker: TrackerConfig
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -216,13 +220,13 @@ class TrainDiffusionPriorConfig(BaseModel):
|
|||||||
|
|
||||||
class UnetConfig(BaseModel):
|
class UnetConfig(BaseModel):
|
||||||
dim: int
|
dim: int
|
||||||
dim_mults: ListOrTuple(int)
|
dim_mults: ListOrTuple[int]
|
||||||
image_embed_dim: int = None
|
image_embed_dim: int = None
|
||||||
text_embed_dim: int = None
|
text_embed_dim: int = None
|
||||||
cond_on_text_encodings: bool = None
|
cond_on_text_encodings: bool = None
|
||||||
cond_dim: int = None
|
cond_dim: int = None
|
||||||
channels: int = 3
|
channels: int = 3
|
||||||
self_attn: ListOrTuple(int)
|
self_attn: ListOrTuple[int]
|
||||||
attn_dim_head: int = 32
|
attn_dim_head: int = 32
|
||||||
attn_heads: int = 16
|
attn_heads: int = 16
|
||||||
init_cross_embed: bool = True
|
init_cross_embed: bool = True
|
||||||
@@ -231,16 +235,16 @@ class UnetConfig(BaseModel):
|
|||||||
extra = "allow"
|
extra = "allow"
|
||||||
|
|
||||||
class DecoderConfig(BaseModel):
|
class DecoderConfig(BaseModel):
|
||||||
unets: ListOrTuple(UnetConfig)
|
unets: ListOrTuple[UnetConfig]
|
||||||
image_size: int = None
|
image_size: int = None
|
||||||
image_sizes: ListOrTuple(int) = None
|
image_sizes: ListOrTuple[int] = None
|
||||||
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
|
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
|
||||||
channels: int = 3
|
channels: int = 3
|
||||||
timesteps: int = 1000
|
timesteps: int = 1000
|
||||||
sample_timesteps: Optional[SingularOrIterable(int)] = None
|
sample_timesteps: Optional[SingularOrIterable[int]] = None
|
||||||
loss_type: str = 'l2'
|
loss_type: str = 'l2'
|
||||||
beta_schedule: ListOrTuple(str) = 'cosine'
|
beta_schedule: ListOrTuple[str] = None # None means all cosine
|
||||||
learned_variance: bool = True
|
learned_variance: SingularOrIterable[bool] = True
|
||||||
image_cond_drop_prob: float = 0.1
|
image_cond_drop_prob: float = 0.1
|
||||||
text_cond_drop_prob: float = 0.5
|
text_cond_drop_prob: float = 0.5
|
||||||
|
|
||||||
@@ -299,11 +303,11 @@ class DecoderDataConfig(BaseModel):
|
|||||||
|
|
||||||
class DecoderTrainConfig(BaseModel):
|
class DecoderTrainConfig(BaseModel):
|
||||||
epochs: int = 20
|
epochs: int = 20
|
||||||
lr: SingularOrIterable(float) = 1e-4
|
lr: SingularOrIterable[float] = 1e-4
|
||||||
wd: SingularOrIterable(float) = 0.01
|
wd: SingularOrIterable[float] = 0.01
|
||||||
warmup_steps: Optional[SingularOrIterable(int)] = None
|
warmup_steps: Optional[SingularOrIterable[int]] = None
|
||||||
find_unused_parameters: bool = True
|
find_unused_parameters: bool = True
|
||||||
max_grad_norm: SingularOrIterable(float) = 0.5
|
max_grad_norm: SingularOrIterable[float] = 0.5
|
||||||
save_every_n_samples: int = 100000
|
save_every_n_samples: int = 100000
|
||||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||||
cond_scale: Union[float, List[float]] = 1.0
|
cond_scale: Union[float, List[float]] = 1.0
|
||||||
@@ -314,7 +318,7 @@ class DecoderTrainConfig(BaseModel):
|
|||||||
use_ema: bool = True
|
use_ema: bool = True
|
||||||
ema_beta: float = 0.999
|
ema_beta: float = 0.999
|
||||||
amp: bool = False
|
amp: bool = False
|
||||||
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
|
unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
|
||||||
|
|
||||||
class DecoderEvaluateConfig(BaseModel):
|
class DecoderEvaluateConfig(BaseModel):
|
||||||
n_evaluation_samples: int = 1000
|
n_evaluation_samples: int = 1000
|
||||||
@@ -323,12 +327,6 @@ class DecoderEvaluateConfig(BaseModel):
|
|||||||
KID: Dict[str, Any] = None
|
KID: Dict[str, Any] = None
|
||||||
LPIPS: Dict[str, Any] = None
|
LPIPS: Dict[str, Any] = None
|
||||||
|
|
||||||
class DecoderLoadConfig(BaseModel):
|
|
||||||
source: str = None # Supports file and wandb
|
|
||||||
run_path: str = '' # Used only if source is wandb
|
|
||||||
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
|
|
||||||
resume: bool = False # If using wandb, whether to resume the run
|
|
||||||
|
|
||||||
class TrainDecoderConfig(BaseModel):
|
class TrainDecoderConfig(BaseModel):
|
||||||
decoder: DecoderConfig
|
decoder: DecoderConfig
|
||||||
data: DecoderDataConfig
|
data: DecoderDataConfig
|
||||||
|
|||||||
@@ -174,26 +174,24 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
diffusion_prior,
|
diffusion_prior,
|
||||||
|
accelerator = None,
|
||||||
use_ema = True,
|
use_ema = True,
|
||||||
lr = 3e-4,
|
lr = 3e-4,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
eps = 1e-6,
|
eps = 1e-6,
|
||||||
max_grad_norm = None,
|
max_grad_norm = None,
|
||||||
amp = False,
|
|
||||||
group_wd_params = True,
|
group_wd_params = True,
|
||||||
device = None,
|
warmup_steps = 1,
|
||||||
accelerator = None,
|
|
||||||
verbose = True,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||||
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
|
|
||||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||||
|
accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
|
||||||
|
|
||||||
# verbosity
|
if not exists(accelerator):
|
||||||
|
accelerator = Accelerator(**accelerator_kwargs)
|
||||||
self.verbose = verbose
|
|
||||||
|
|
||||||
# assign some helpful member vars
|
# assign some helpful member vars
|
||||||
|
|
||||||
@@ -202,23 +200,31 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
# setting the device
|
# setting the device
|
||||||
|
|
||||||
if not exists(accelerator) and not exists(device):
|
self.device = accelerator.device
|
||||||
diffusion_prior_device = next(diffusion_prior.parameters()).device
|
diffusion_prior.to(self.device)
|
||||||
self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
|
|
||||||
self.device = diffusion_prior_device
|
|
||||||
else:
|
|
||||||
self.device = accelerator.device if exists(accelerator) else device
|
|
||||||
diffusion_prior.to(self.device)
|
|
||||||
|
|
||||||
# save model
|
# save model
|
||||||
|
|
||||||
self.diffusion_prior = diffusion_prior
|
self.diffusion_prior = diffusion_prior
|
||||||
|
|
||||||
# optimizer and mixed precision stuff
|
# mixed precision checks
|
||||||
|
|
||||||
self.amp = amp
|
if (
|
||||||
|
exists(self.accelerator)
|
||||||
|
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
|
||||||
|
and self.diffusion_prior.clip is not None
|
||||||
|
):
|
||||||
|
# Then we need to make sure clip is using the correct precision or else deepspeed will error
|
||||||
|
cast_type_map = {
|
||||||
|
"fp16": torch.half,
|
||||||
|
"bf16": torch.bfloat16,
|
||||||
|
"no": torch.float
|
||||||
|
}
|
||||||
|
precision_type = cast_type_map[accelerator.mixed_precision]
|
||||||
|
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
|
||||||
|
self.diffusion_prior.clip.to(precision_type)
|
||||||
|
|
||||||
self.scaler = GradScaler(enabled = amp)
|
# optimizer stuff
|
||||||
|
|
||||||
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
|
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
|
||||||
|
|
||||||
@@ -227,17 +233,21 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
**self.optim_kwargs,
|
**self.optim_kwargs,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
|
||||||
|
|
||||||
|
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
|
||||||
|
|
||||||
# distribute the model if using HFA
|
# distribute the model if using HFA
|
||||||
if exists(self.accelerator):
|
|
||||||
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
|
self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)
|
||||||
|
|
||||||
# exponential moving average stuff
|
# exponential moving average stuff
|
||||||
|
|
||||||
self.use_ema = use_ema
|
self.use_ema = use_ema
|
||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
|
self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)
|
||||||
|
|
||||||
# gradient clipping if needed
|
# gradient clipping if needed
|
||||||
|
|
||||||
@@ -247,67 +257,24 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
self.register_buffer('step', torch.tensor([0], device = self.device))
|
self.register_buffer('step', torch.tensor([0], device = self.device))
|
||||||
|
|
||||||
# accelerator wrappers
|
|
||||||
|
|
||||||
def print(self, msg):
|
|
||||||
if not self.verbose:
|
|
||||||
return
|
|
||||||
|
|
||||||
if exists(self.accelerator):
|
|
||||||
self.accelerator.print(msg)
|
|
||||||
else:
|
|
||||||
print(msg)
|
|
||||||
|
|
||||||
def unwrap_model(self, model):
|
|
||||||
if exists(self.accelerator):
|
|
||||||
return self.accelerator.unwrap_model(model)
|
|
||||||
else:
|
|
||||||
return model
|
|
||||||
|
|
||||||
def wait_for_everyone(self):
|
|
||||||
if exists(self.accelerator):
|
|
||||||
self.accelerator.wait_for_everyone()
|
|
||||||
|
|
||||||
def is_main_process(self):
|
|
||||||
if exists(self.accelerator):
|
|
||||||
return self.accelerator.is_main_process
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def clip_grad_norm_(self, *args):
|
|
||||||
if exists(self.accelerator):
|
|
||||||
return self.accelerator.clip_grad_norm_(*args)
|
|
||||||
else:
|
|
||||||
return torch.nn.utils.clip_grad_norm_(*args)
|
|
||||||
|
|
||||||
def backprop(self, x):
|
|
||||||
if exists(self.accelerator):
|
|
||||||
self.accelerator.backward(x)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
x.backward()
|
|
||||||
except Exception as e:
|
|
||||||
self.print(f"Caught error in backprop call: {e}")
|
|
||||||
|
|
||||||
# utility
|
# utility
|
||||||
|
|
||||||
def save(self, path, overwrite = True, **kwargs):
|
def save(self, path, overwrite = True, **kwargs):
|
||||||
# ensure we sync gradients before continuing
|
|
||||||
self.wait_for_everyone()
|
|
||||||
|
|
||||||
# only save on the main process
|
# only save on the main process
|
||||||
if self.is_main_process():
|
if self.accelerator.is_main_process:
|
||||||
self.print(f"Saving checkpoint at step: {self.step.item()}")
|
print(f"Saving checkpoint at step: {self.step.item()}")
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
assert not (path.exists() and not overwrite)
|
assert not (path.exists() and not overwrite)
|
||||||
path.parent.mkdir(parents = True, exist_ok = True)
|
path.parent.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
|
# FIXME: LambdaLR can't be saved due to pickling issues
|
||||||
save_obj = dict(
|
save_obj = dict(
|
||||||
scaler = self.scaler.state_dict(),
|
|
||||||
optimizer = self.optimizer.state_dict(),
|
optimizer = self.optimizer.state_dict(),
|
||||||
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
|
warmup_scheduler = self.warmup_scheduler,
|
||||||
|
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
|
||||||
version = version.parse(__version__),
|
version = version.parse(__version__),
|
||||||
step = self.step.item(),
|
step = self.step,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -320,14 +287,14 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
torch.save(save_obj, str(path))
|
torch.save(save_obj, str(path))
|
||||||
|
|
||||||
def load(self, path, overwrite_lr = True, strict = True):
|
def load(self, path_or_state, overwrite_lr = True, strict = True):
|
||||||
"""
|
"""
|
||||||
Load a checkpoint of a diffusion prior trainer.
|
Load a checkpoint of a diffusion prior trainer.
|
||||||
|
|
||||||
Will load the entire trainer, including the optimizer and EMA.
|
Will load the entire trainer, including the optimizer and EMA.
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
- path (str): a path to the DiffusionPriorTrainer checkpoint file
|
- path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file
|
||||||
- overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
|
- overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
|
||||||
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
|
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
|
||||||
|
|
||||||
@@ -336,56 +303,56 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# all processes need to load checkpoint. no restriction here
|
# all processes need to load checkpoint. no restriction here
|
||||||
path = Path(path)
|
if isinstance(path_or_state, str):
|
||||||
assert path.exists()
|
path = Path(path_or_state)
|
||||||
|
assert path.exists()
|
||||||
|
loaded_obj = torch.load(str(path), map_location=self.device)
|
||||||
|
|
||||||
loaded_obj = torch.load(str(path), map_location=self.device)
|
elif isinstance(path_or_state, dict):
|
||||||
|
loaded_obj = path_or_state
|
||||||
|
|
||||||
if version.parse(__version__) != loaded_obj['version']:
|
if version.parse(__version__) != loaded_obj['version']:
|
||||||
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
|
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
|
||||||
|
|
||||||
# unwrap the model when loading from checkpoint
|
# unwrap the model when loading from checkpoint
|
||||||
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
||||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
|
||||||
|
|
||||||
self.scaler.load_state_dict(loaded_obj['scaler'])
|
|
||||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||||
|
|
||||||
|
# set warmupstep
|
||||||
|
if exists(self.warmup_scheduler):
|
||||||
|
self.warmup_scheduler.last_step = self.step.item()
|
||||||
|
|
||||||
|
# ensure new lr is used if different from old one
|
||||||
if overwrite_lr:
|
if overwrite_lr:
|
||||||
new_lr = self.optim_kwargs["lr"]
|
new_lr = self.optim_kwargs["lr"]
|
||||||
|
|
||||||
self.print(f"Overriding LR to be {new_lr}")
|
|
||||||
|
|
||||||
for group in self.optimizer.param_groups:
|
for group in self.optimizer.param_groups:
|
||||||
group["lr"] = new_lr
|
group["lr"] = new_lr if group["lr"] > 0.0 else 0.0
|
||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
assert 'ema' in loaded_obj
|
assert 'ema' in loaded_obj
|
||||||
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||||
# below not be necessary, but I had a suspicion that this wasn't being loaded correctly
|
# below might not be necessary, but I had a suspicion that this wasn't being loaded correctly
|
||||||
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
|
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
|
||||||
|
|
||||||
# sync and inform
|
|
||||||
self.wait_for_everyone()
|
|
||||||
self.print(f"Loaded model")
|
|
||||||
|
|
||||||
return loaded_obj
|
return loaded_obj
|
||||||
|
|
||||||
# model functionality
|
# model functionality
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
# only continue with updates until all ranks finish
|
|
||||||
self.wait_for_everyone()
|
|
||||||
|
|
||||||
if exists(self.max_grad_norm):
|
if exists(self.max_grad_norm):
|
||||||
self.scaler.unscale_(self.optimizer)
|
self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||||
# utilize HFA clipping where applicable
|
|
||||||
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
self.optimizer.step()
|
||||||
|
|
||||||
self.scaler.step(self.optimizer)
|
|
||||||
self.scaler.update()
|
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
|
||||||
|
if not self.accelerator.optimizer_step_was_skipped:
|
||||||
|
with self.warmup_scheduler.dampening():
|
||||||
|
self.scheduler.step()
|
||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.ema_diffusion_prior.update()
|
self.ema_diffusion_prior.update()
|
||||||
|
|
||||||
@@ -414,7 +381,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
@cast_torch_tensor
|
@cast_torch_tensor
|
||||||
@prior_sample_in_chunks
|
@prior_sample_in_chunks
|
||||||
def embed_text(self, *args, **kwargs):
|
def embed_text(self, *args, **kwargs):
|
||||||
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
|
return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
|
||||||
|
|
||||||
@cast_torch_tensor
|
@cast_torch_tensor
|
||||||
def forward(
|
def forward(
|
||||||
@@ -426,16 +393,14 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
total_loss = 0.
|
total_loss = 0.
|
||||||
|
|
||||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||||
with autocast(enabled = self.amp):
|
with self.accelerator.autocast():
|
||||||
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
||||||
loss = loss * chunk_size_frac
|
loss = loss * chunk_size_frac
|
||||||
|
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
|
|
||||||
# backprop with accelerate if applicable
|
|
||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
self.backprop(self.scaler.scale(loss))
|
self.accelerator.backward(loss)
|
||||||
|
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '1.0.0'
|
__version__ = '1.2.2'
|
||||||
|
|||||||
@@ -1,31 +1,23 @@
|
|||||||
# TODO: add start, num_data_points, eval_every and group to config
|
|
||||||
# TODO: switch back to repo's wandb
|
|
||||||
|
|
||||||
START = 0
|
|
||||||
NUM_DATA_POINTS = 250e6
|
|
||||||
EVAL_EVERY = 1000
|
|
||||||
GROUP = "distributed"
|
|
||||||
|
|
||||||
import os
|
|
||||||
import click
|
import click
|
||||||
import wandb
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
from accelerate.utils import set_seed
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from embedding_reader import EmbeddingReader
|
||||||
|
from accelerate.utils import dataclasses as accelerate_dataclasses
|
||||||
|
|
||||||
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
|
||||||
from dalle2_pytorch.utils import Timer
|
from dalle2_pytorch.utils import Timer
|
||||||
|
from dalle2_pytorch.trackers import Tracker
|
||||||
|
from dalle2_pytorch import DiffusionPriorTrainer
|
||||||
|
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||||
from dalle2_pytorch.train_configs import (
|
from dalle2_pytorch.train_configs import (
|
||||||
|
DiffusionPriorConfig,
|
||||||
DiffusionPriorTrainConfig,
|
DiffusionPriorTrainConfig,
|
||||||
TrainDiffusionPriorConfig,
|
TrainDiffusionPriorConfig,
|
||||||
)
|
)
|
||||||
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
|
|
||||||
from dalle2_pytorch import DiffusionPriorTrainer
|
|
||||||
|
|
||||||
|
|
||||||
# helpers
|
# helpers
|
||||||
@@ -38,8 +30,19 @@ def exists(val):
|
|||||||
return val is not None
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def all_between(values: list, lower_bound, upper_bound):
|
||||||
|
for value in values:
|
||||||
|
if value < lower_bound or value > upper_bound:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def make_model(
|
def make_model(
|
||||||
prior_config, train_config, device: str = None, accelerator: Accelerator = None
|
prior_config: DiffusionPriorConfig,
|
||||||
|
train_config: DiffusionPriorTrainConfig,
|
||||||
|
device: str = None,
|
||||||
|
accelerator: Accelerator = None,
|
||||||
):
|
):
|
||||||
# create model from config
|
# create model from config
|
||||||
diffusion_prior = prior_config.create()
|
diffusion_prior = prior_config.create()
|
||||||
@@ -54,71 +57,214 @@ def make_model(
|
|||||||
use_ema=train_config.use_ema,
|
use_ema=train_config.use_ema,
|
||||||
device=device,
|
device=device,
|
||||||
accelerator=accelerator,
|
accelerator=accelerator,
|
||||||
|
warmup_steps=train_config.warmup_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
|
|
||||||
|
def create_tracker(
|
||||||
|
accelerator: Accelerator,
|
||||||
|
config: TrainDiffusionPriorConfig,
|
||||||
|
config_path: str,
|
||||||
|
dummy: bool = False,
|
||||||
|
) -> Tracker:
|
||||||
|
tracker_config = config.tracker
|
||||||
|
|
||||||
|
accelerator_config = {
|
||||||
|
"Distributed": accelerator.distributed_type
|
||||||
|
!= accelerate_dataclasses.DistributedType.NO,
|
||||||
|
"DistributedType": accelerator.distributed_type,
|
||||||
|
"NumProcesses": accelerator.num_processes,
|
||||||
|
"MixedPrecision": accelerator.mixed_precision,
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker: Tracker = tracker_config.create(
|
||||||
|
config, accelerator_config, dummy_mode=dummy
|
||||||
|
)
|
||||||
|
|
||||||
|
tracker.save_config(config_path, config_name="prior_config.json")
|
||||||
|
|
||||||
|
return tracker
|
||||||
|
|
||||||
|
|
||||||
|
def pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method="mean"):
|
||||||
|
"""
|
||||||
|
pad a value or tensor across all processes and gather
|
||||||
|
|
||||||
|
params:
|
||||||
|
- trainer: a trainer that carries an accelerator object
|
||||||
|
- x: a number or torch tensor to reduce
|
||||||
|
- method: "mean", "sum", "max", "min"
|
||||||
|
|
||||||
|
return:
|
||||||
|
- the average tensor after maskin out 0's
|
||||||
|
- None if the gather resulted in an empty tensor
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert method in [
|
||||||
|
"mean",
|
||||||
|
"sum",
|
||||||
|
"max",
|
||||||
|
"min",
|
||||||
|
], "This function has limited capabilities [sum, mean, max, min]"
|
||||||
|
assert type(x) is not None, "Cannot reduce a None type object"
|
||||||
|
|
||||||
|
# wait for everyone to arrive here before gathering
|
||||||
|
|
||||||
|
if type(x) is not torch.Tensor:
|
||||||
|
x = torch.tensor([x])
|
||||||
|
|
||||||
|
# verify that the tensor is on the proper device
|
||||||
|
x = x.to(trainer.device)
|
||||||
|
|
||||||
|
# pad across processes
|
||||||
|
padded_x = trainer.accelerator.pad_across_processes(x, dim=0)
|
||||||
|
|
||||||
|
# gather across all procesess
|
||||||
|
gathered_x = trainer.accelerator.gather(padded_x)
|
||||||
|
|
||||||
|
# mask out zeros
|
||||||
|
masked_x = gathered_x[gathered_x != 0]
|
||||||
|
|
||||||
|
# if the tensor is empty, warn and return None
|
||||||
|
if len(masked_x) == 0:
|
||||||
|
click.secho(
|
||||||
|
f"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.",
|
||||||
|
fg="red",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if method == "mean":
|
||||||
|
return torch.mean(masked_x)
|
||||||
|
elif method == "sum":
|
||||||
|
return torch.sum(masked_x)
|
||||||
|
elif method == "max":
|
||||||
|
return torch.max(masked_x)
|
||||||
|
elif method == "min":
|
||||||
|
return torch.min(masked_x)
|
||||||
|
|
||||||
|
|
||||||
|
def save_trainer(
|
||||||
|
tracker: Tracker,
|
||||||
|
trainer: DiffusionPriorTrainer,
|
||||||
|
is_latest: bool,
|
||||||
|
is_best: bool,
|
||||||
|
epoch: int,
|
||||||
|
samples_seen: int,
|
||||||
|
best_validation_loss: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Logs the model with an appropriate method depending on the tracker
|
||||||
|
"""
|
||||||
|
trainer.accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
if trainer.accelerator.is_main_process:
|
||||||
|
click.secho(
|
||||||
|
f"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}",
|
||||||
|
fg="magenta",
|
||||||
|
)
|
||||||
|
|
||||||
|
tracker.save(
|
||||||
|
trainer=trainer,
|
||||||
|
is_best=is_best,
|
||||||
|
is_latest=is_latest,
|
||||||
|
epoch=int(epoch),
|
||||||
|
samples_seen=int(samples_seen),
|
||||||
|
best_validation_loss=best_validation_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer):
|
||||||
|
"""
|
||||||
|
Loads the model with an appropriate method depending on the tracker
|
||||||
|
"""
|
||||||
|
|
||||||
|
if trainer.accelerator.is_main_process:
|
||||||
|
click.secho(f"Loading model from {type(tracker.loader).__name__}", fg="yellow")
|
||||||
|
|
||||||
|
state_dict = tracker.recall()
|
||||||
|
|
||||||
|
trainer.load(state_dict, strict=True)
|
||||||
|
|
||||||
|
return (
|
||||||
|
int(state_dict.get("epoch", 0)),
|
||||||
|
state_dict.get("best_validation_loss", 0),
|
||||||
|
int(state_dict.get("samples_seen", 0)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# eval functions
|
# eval functions
|
||||||
|
|
||||||
|
|
||||||
def eval_model(
|
def report_validation_loss(
|
||||||
trainer: DiffusionPriorTrainer,
|
trainer: DiffusionPriorTrainer,
|
||||||
dataloader: DataLoader,
|
dataloader: DataLoader,
|
||||||
text_conditioned: bool,
|
text_conditioned: bool,
|
||||||
|
use_ema: bool,
|
||||||
|
tracker: Tracker,
|
||||||
|
split: str,
|
||||||
|
tracker_folder: str,
|
||||||
loss_type: str,
|
loss_type: str,
|
||||||
tracker_context: str,
|
|
||||||
tracker: BaseTracker = None,
|
|
||||||
use_ema: bool = True,
|
|
||||||
):
|
):
|
||||||
trainer.eval()
|
"""
|
||||||
if trainer.is_main_process():
|
Compute the validation loss on a given subset of data.
|
||||||
click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
|
"""
|
||||||
|
|
||||||
with torch.no_grad():
|
if trainer.accelerator.is_main_process:
|
||||||
total_loss = 0.0
|
click.secho(
|
||||||
total_samples = 0.0
|
f"Measuring performance on {use_ema}-{split} split",
|
||||||
|
fg="green",
|
||||||
|
blink=True,
|
||||||
|
)
|
||||||
|
|
||||||
for image_embeddings, text_data in dataloader:
|
total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device)
|
||||||
image_embeddings = image_embeddings.to(trainer.device)
|
|
||||||
text_data = text_data.to(trainer.device)
|
|
||||||
|
|
||||||
batches = image_embeddings.shape[0]
|
for image_embeddings, text_data in dataloader:
|
||||||
|
image_embeddings = image_embeddings.to(trainer.device)
|
||||||
|
text_data = text_data.to(trainer.device)
|
||||||
|
|
||||||
input_args = dict(image_embed=image_embeddings)
|
input_args = dict(image_embed=image_embeddings)
|
||||||
|
|
||||||
if text_conditioned:
|
if text_conditioned:
|
||||||
input_args = dict(**input_args, text=text_data)
|
input_args = dict(**input_args, text=text_data)
|
||||||
else:
|
else:
|
||||||
input_args = dict(**input_args, text_embed=text_data)
|
input_args = dict(**input_args, text_embed=text_data)
|
||||||
|
|
||||||
if use_ema:
|
if use_ema:
|
||||||
loss = trainer.ema_diffusion_prior(**input_args)
|
loss = trainer.ema_diffusion_prior(**input_args)
|
||||||
else:
|
else:
|
||||||
loss = trainer(**input_args)
|
loss = trainer(**input_args)
|
||||||
|
|
||||||
total_loss += loss * batches
|
total_loss += loss
|
||||||
total_samples += batches
|
|
||||||
|
|
||||||
avg_loss = total_loss / total_samples
|
# compute the average loss across all processes
|
||||||
|
|
||||||
stats = {f"{tracker_context}-{loss_type}": avg_loss}
|
avg_loss = pad_gather_reduce(trainer, total_loss, method="mean")
|
||||||
trainer.print(stats)
|
stats = {f"{tracker_folder}/{loss_type}-loss": avg_loss}
|
||||||
|
|
||||||
if exists(tracker):
|
# print and log results on main process
|
||||||
tracker.log(stats, step=trainer.step.item() + 1)
|
tracker.log(stats, step=trainer.step.item() + 1)
|
||||||
|
|
||||||
|
return avg_loss
|
||||||
|
|
||||||
|
|
||||||
def report_cosine_sims(
|
def report_cosine_sims(
|
||||||
trainer: DiffusionPriorTrainer,
|
trainer: DiffusionPriorTrainer,
|
||||||
dataloader: DataLoader,
|
dataloader: DataLoader,
|
||||||
text_conditioned: bool,
|
text_conditioned: bool,
|
||||||
tracker: BaseTracker,
|
tracker: Tracker,
|
||||||
tracker_context: str = "validation",
|
split: str,
|
||||||
|
timesteps: int,
|
||||||
|
tracker_folder: str,
|
||||||
):
|
):
|
||||||
trainer.eval()
|
trainer.eval()
|
||||||
if trainer.is_main_process():
|
if trainer.accelerator.is_main_process:
|
||||||
click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
|
click.secho(
|
||||||
|
f"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps",
|
||||||
|
fg="green",
|
||||||
|
blink=True,
|
||||||
|
)
|
||||||
|
|
||||||
for test_image_embeddings, text_data in dataloader:
|
for test_image_embeddings, text_data in dataloader:
|
||||||
test_image_embeddings = test_image_embeddings.to(trainer.device)
|
test_image_embeddings = test_image_embeddings.to(trainer.device)
|
||||||
@@ -127,9 +273,7 @@ def report_cosine_sims(
|
|||||||
# we are text conditioned, we produce an embedding from the tokenized text
|
# we are text conditioned, we produce an embedding from the tokenized text
|
||||||
if text_conditioned:
|
if text_conditioned:
|
||||||
text_embedding, text_encodings = trainer.embed_text(text_data)
|
text_embedding, text_encodings = trainer.embed_text(text_data)
|
||||||
text_cond = dict(
|
text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings)
|
||||||
text_embed=text_embedding, text_encodings=text_encodings
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
text_embedding = text_data
|
text_embedding = text_data
|
||||||
text_cond = dict(text_embed=text_embedding)
|
text_cond = dict(text_embed=text_embedding)
|
||||||
@@ -150,8 +294,7 @@ def report_cosine_sims(
|
|||||||
text_encodings_shuffled = None
|
text_encodings_shuffled = None
|
||||||
|
|
||||||
text_cond_shuffled = dict(
|
text_cond_shuffled = dict(
|
||||||
text_embed=text_embed_shuffled,
|
text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled
|
||||||
text_encodings=text_encodings_shuffled
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# prepare the text embedding
|
# prepare the text embedding
|
||||||
@@ -164,7 +307,9 @@ def report_cosine_sims(
|
|||||||
|
|
||||||
# predict on the unshuffled text embeddings
|
# predict on the unshuffled text embeddings
|
||||||
predicted_image_embeddings = trainer.p_sample_loop(
|
predicted_image_embeddings = trainer.p_sample_loop(
|
||||||
test_image_embeddings.shape, text_cond
|
test_image_embeddings.shape,
|
||||||
|
text_cond,
|
||||||
|
timesteps=timesteps,
|
||||||
)
|
)
|
||||||
|
|
||||||
predicted_image_embeddings = (
|
predicted_image_embeddings = (
|
||||||
@@ -174,7 +319,9 @@ def report_cosine_sims(
|
|||||||
|
|
||||||
# predict on the shuffled embeddings
|
# predict on the shuffled embeddings
|
||||||
predicted_unrelated_embeddings = trainer.p_sample_loop(
|
predicted_unrelated_embeddings = trainer.p_sample_loop(
|
||||||
test_image_embeddings.shape, text_cond_shuffled
|
test_image_embeddings.shape,
|
||||||
|
text_cond_shuffled,
|
||||||
|
timesteps=timesteps,
|
||||||
)
|
)
|
||||||
|
|
||||||
predicted_unrelated_embeddings = (
|
predicted_unrelated_embeddings = (
|
||||||
@@ -183,32 +330,97 @@ def report_cosine_sims(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# calculate similarities
|
# calculate similarities
|
||||||
original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
|
orig_sim = pad_gather_reduce(
|
||||||
predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
|
trainer, cos(text_embed, test_image_embeddings), method="mean"
|
||||||
unrelated_similarity = (
|
|
||||||
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
|
||||||
)
|
)
|
||||||
predicted_img_similarity = (
|
pred_sim = pad_gather_reduce(
|
||||||
cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
trainer, cos(text_embed, predicted_image_embeddings), method="mean"
|
||||||
|
)
|
||||||
|
unrel_sim = pad_gather_reduce(
|
||||||
|
trainer, cos(text_embed, predicted_unrelated_embeddings), method="mean"
|
||||||
|
)
|
||||||
|
pred_img_sim = pad_gather_reduce(
|
||||||
|
trainer,
|
||||||
|
cos(test_image_embeddings, predicted_image_embeddings),
|
||||||
|
method="mean",
|
||||||
)
|
)
|
||||||
|
|
||||||
stats = {
|
stats = {
|
||||||
f"{tracker_context}/baseline similarity": np.mean(original_similarity),
|
f"{tracker_folder}/baseline similarity [steps={timesteps}]": orig_sim,
|
||||||
f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
|
f"{tracker_folder}/similarity with text [steps={timesteps}]": pred_sim,
|
||||||
f"{tracker_context}/similarity with original image": np.mean(
|
f"{tracker_folder}/similarity with original image [steps={timesteps}]": pred_img_sim,
|
||||||
predicted_img_similarity
|
f"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]": unrel_sim,
|
||||||
),
|
f"{tracker_folder}/difference from baseline similarity [steps={timesteps}]": pred_sim
|
||||||
f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
|
- orig_sim,
|
||||||
f"{tracker_context}/difference from baseline similarity": np.mean(
|
|
||||||
predicted_similarity - original_similarity
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v in stats.items():
|
tracker.log(stats, step=trainer.step.item() + 1)
|
||||||
trainer.print(f"{tracker_context}/{k}: {v}")
|
|
||||||
|
|
||||||
if exists(tracker):
|
|
||||||
tracker.log(stats, step=trainer.step.item() + 1)
|
def eval_model(
|
||||||
|
trainer: DiffusionPriorTrainer,
|
||||||
|
dataloader: DataLoader,
|
||||||
|
text_conditioned: bool,
|
||||||
|
split: str,
|
||||||
|
tracker: Tracker,
|
||||||
|
use_ema: bool,
|
||||||
|
report_cosine: bool,
|
||||||
|
report_loss: bool,
|
||||||
|
timesteps: List[int],
|
||||||
|
loss_type: str = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run evaluation on a model and track metrics
|
||||||
|
|
||||||
|
returns: loss if requested
|
||||||
|
"""
|
||||||
|
trainer.eval()
|
||||||
|
|
||||||
|
use_ema = "ema" if use_ema else "online"
|
||||||
|
tracker_folder = f"metrics/{use_ema}-{split}"
|
||||||
|
|
||||||
|
# detemine if valid timesteps are passed
|
||||||
|
|
||||||
|
min_timesteps = trainer.accelerator.unwrap_model(
|
||||||
|
trainer.diffusion_prior
|
||||||
|
).sample_timesteps
|
||||||
|
max_timesteps = trainer.accelerator.unwrap_model(
|
||||||
|
trainer.diffusion_prior
|
||||||
|
).noise_scheduler.num_timesteps
|
||||||
|
|
||||||
|
assert all_between(
|
||||||
|
timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps
|
||||||
|
), f"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}"
|
||||||
|
|
||||||
|
# measure cosine metrics across various eta and timesteps
|
||||||
|
|
||||||
|
if report_cosine:
|
||||||
|
for timestep in timesteps:
|
||||||
|
report_cosine_sims(
|
||||||
|
trainer,
|
||||||
|
dataloader=dataloader,
|
||||||
|
text_conditioned=text_conditioned,
|
||||||
|
tracker=tracker,
|
||||||
|
split=split,
|
||||||
|
timesteps=timestep,
|
||||||
|
tracker_folder=tracker_folder,
|
||||||
|
)
|
||||||
|
|
||||||
|
# measure loss on a seperate split of data
|
||||||
|
|
||||||
|
if report_loss:
|
||||||
|
loss = report_validation_loss(
|
||||||
|
trainer=trainer,
|
||||||
|
dataloader=dataloader,
|
||||||
|
text_conditioned=text_conditioned,
|
||||||
|
use_ema=use_ema,
|
||||||
|
tracker=tracker,
|
||||||
|
split=split,
|
||||||
|
tracker_folder=tracker_folder,
|
||||||
|
loss_type=loss_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
# training script
|
# training script
|
||||||
@@ -216,182 +428,327 @@ def report_cosine_sims(
|
|||||||
|
|
||||||
def train(
|
def train(
|
||||||
trainer: DiffusionPriorTrainer,
|
trainer: DiffusionPriorTrainer,
|
||||||
|
tracker: Tracker,
|
||||||
train_loader: DataLoader,
|
train_loader: DataLoader,
|
||||||
eval_loader: DataLoader,
|
eval_loader: DataLoader,
|
||||||
test_loader: DataLoader,
|
test_loader: DataLoader,
|
||||||
config: DiffusionPriorTrainConfig,
|
config: DiffusionPriorTrainConfig,
|
||||||
):
|
):
|
||||||
# distributed tracking with wandb
|
# init timers
|
||||||
if trainer.accelerator.num_processes > 1:
|
save_timer = Timer() # when to save
|
||||||
os.environ["WANDB_START_METHOD"] = "thread"
|
samples_timer = Timer() # samples/sec
|
||||||
|
validation_profiler = Timer() # how long is validation taking
|
||||||
|
validation_countdown = Timer() # when to perform evalutation
|
||||||
|
|
||||||
tracker = wandb.init(
|
# keep track of best validation loss
|
||||||
name=f"RANK:{trainer.device}",
|
|
||||||
entity=config.tracker.wandb_entity,
|
|
||||||
project=config.tracker.wandb_project,
|
|
||||||
config=config.dict(),
|
|
||||||
group=GROUP,
|
|
||||||
)
|
|
||||||
|
|
||||||
# sync after tracker init
|
best_validation_loss = config.train.best_validation_loss
|
||||||
trainer.wait_for_everyone()
|
samples_seen = config.train.num_samples_seen
|
||||||
|
|
||||||
# init a timer
|
|
||||||
timer = Timer()
|
|
||||||
|
|
||||||
# do training
|
# do training
|
||||||
for img, txt in train_loader:
|
|
||||||
trainer.train()
|
|
||||||
current_step = trainer.step.item() + 1
|
|
||||||
|
|
||||||
# place data on device
|
start_epoch = config.train.current_epoch
|
||||||
img = img.to(trainer.device)
|
|
||||||
txt = txt.to(trainer.device)
|
|
||||||
|
|
||||||
# pass to model
|
for epoch in range(start_epoch, config.train.epochs):
|
||||||
loss = trainer(text=txt, image_embed=img)
|
# if we finished out an old epoch, reset the distribution to be a full epoch
|
||||||
|
tracker.log({"tracking/epoch": epoch}, step=trainer.step.item())
|
||||||
|
|
||||||
# display & log loss (will only print from main process)
|
if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1:
|
||||||
trainer.print(f"Step {current_step}: Loss {loss}")
|
if trainer.accelerator.is_main_process:
|
||||||
|
click.secho(f"Finished resumed epoch...resetting dataloader.")
|
||||||
|
train_loader.dataset.set_start(0)
|
||||||
|
|
||||||
# perform backprop & apply EMA updates
|
for img, txt in train_loader:
|
||||||
trainer.update()
|
# setup things every step
|
||||||
|
|
||||||
# track samples/sec/rank
|
trainer.train()
|
||||||
samples_per_sec = img.shape[0] / timer.elapsed()
|
current_step = trainer.step.item()
|
||||||
|
samples_timer.reset()
|
||||||
|
|
||||||
# samples seen
|
# place data on device
|
||||||
samples_seen = (
|
|
||||||
config.data.batch_size * trainer.accelerator.num_processes * current_step
|
|
||||||
)
|
|
||||||
|
|
||||||
# ema decay
|
img = img.to(trainer.device)
|
||||||
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
|
txt = txt.to(trainer.device)
|
||||||
|
|
||||||
# Log on all processes for debugging
|
# pass to model
|
||||||
tracker.log(
|
|
||||||
{
|
|
||||||
"tracking/samples-sec": samples_per_sec,
|
|
||||||
"tracking/samples-seen": samples_seen,
|
|
||||||
"tracking/ema-decay": ema_decay,
|
|
||||||
"metrics/training-loss": loss,
|
|
||||||
},
|
|
||||||
step=current_step,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Metric Tracking & Checkpointing (outside of timer's scope)
|
loss = trainer(text=txt, image_embed=img)
|
||||||
if current_step % EVAL_EVERY == 0:
|
|
||||||
eval_model(
|
# perform backprop & apply EMA updates
|
||||||
trainer=trainer,
|
|
||||||
dataloader=eval_loader,
|
trainer.update()
|
||||||
text_conditioned=config.prior.condition_on_text_encodings,
|
|
||||||
loss_type=config.prior.loss_type,
|
# gather info about training step
|
||||||
tracker_context="metrics/online-model-validation",
|
|
||||||
tracker=tracker,
|
all_loss = pad_gather_reduce(trainer, loss, method="mean")
|
||||||
use_ema=False,
|
num_samples = pad_gather_reduce(trainer, len(txt), method="sum")
|
||||||
|
samples_per_sec = num_samples / samples_timer.elapsed()
|
||||||
|
samples_seen += num_samples
|
||||||
|
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
|
||||||
|
|
||||||
|
# log
|
||||||
|
|
||||||
|
tracker.log(
|
||||||
|
{
|
||||||
|
"tracking/samples-sec": samples_per_sec,
|
||||||
|
"tracking/samples-seen": samples_seen,
|
||||||
|
"tracking/ema-decay": ema_decay,
|
||||||
|
f"tracking/training-{config.prior.loss_type}": all_loss,
|
||||||
|
},
|
||||||
|
step=current_step,
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_model(
|
# Metric Tracking @ Timed Intervals
|
||||||
trainer=trainer,
|
|
||||||
dataloader=eval_loader,
|
eval_delta = pad_gather_reduce(
|
||||||
text_conditioned=config.prior.condition_on_text_encodings,
|
trainer, validation_countdown.elapsed(), method="min"
|
||||||
loss_type=config.prior.loss_type,
|
|
||||||
tracker_context="metrics/ema-model-validation",
|
|
||||||
tracker=tracker,
|
|
||||||
use_ema=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
report_cosine_sims(
|
if eval_delta != None and eval_delta > config.data.eval_every_seconds:
|
||||||
trainer=trainer,
|
# begin timing how long this takes
|
||||||
dataloader=eval_loader,
|
|
||||||
text_conditioned=config.prior.condition_on_text_encodings,
|
|
||||||
tracker=tracker,
|
|
||||||
tracker_context="metrics",
|
|
||||||
)
|
|
||||||
|
|
||||||
if current_step % config.train.save_every == 0:
|
validation_profiler.reset()
|
||||||
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
|
|
||||||
|
|
||||||
# reset timer for next round
|
# package kwargs for evaluation
|
||||||
timer.reset()
|
|
||||||
|
eval_kwargs = {
|
||||||
|
"trainer": trainer,
|
||||||
|
"tracker": tracker,
|
||||||
|
"text_conditioned": config.prior.condition_on_text_encodings,
|
||||||
|
"timesteps": config.train.eval_timesteps,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT
|
||||||
|
|
||||||
|
eval_model(
|
||||||
|
dataloader=eval_loader,
|
||||||
|
loss_type=config.prior.loss_type,
|
||||||
|
split="validation",
|
||||||
|
use_ema=False,
|
||||||
|
report_cosine=False,
|
||||||
|
report_loss=True,
|
||||||
|
**eval_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# EMA MODEL : COSINE : LOSS : VALIDATION DATA
|
||||||
|
|
||||||
|
ema_val_loss = eval_model(
|
||||||
|
dataloader=eval_loader,
|
||||||
|
loss_type=config.prior.loss_type,
|
||||||
|
split="validation",
|
||||||
|
use_ema=True,
|
||||||
|
report_cosine=True,
|
||||||
|
report_loss=True,
|
||||||
|
**eval_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
tracker.log(
|
||||||
|
{
|
||||||
|
"tracking/validation length (minutes)": validation_profiler.elapsed()
|
||||||
|
/ 60
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if the ema validation is the lowest seen yet
|
||||||
|
|
||||||
|
if ema_val_loss < best_validation_loss:
|
||||||
|
best_validation_loss = ema_val_loss
|
||||||
|
|
||||||
|
# go save the model as best
|
||||||
|
|
||||||
|
save_trainer(
|
||||||
|
trainer=trainer,
|
||||||
|
tracker=tracker,
|
||||||
|
is_best=True,
|
||||||
|
is_latest=False,
|
||||||
|
samples_seen=samples_seen,
|
||||||
|
epoch=epoch,
|
||||||
|
best_validation_loss=best_validation_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
# reset timer for validaiton
|
||||||
|
|
||||||
|
validation_countdown.reset()
|
||||||
|
|
||||||
|
elif eval_delta is None:
|
||||||
|
click.secho(
|
||||||
|
f"Error occured reading the eval time on rank: {trainer.device}",
|
||||||
|
fg="yellow",
|
||||||
|
)
|
||||||
|
|
||||||
|
# save as latest model on schedule
|
||||||
|
|
||||||
|
save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method="min")
|
||||||
|
|
||||||
|
if save_delta != None and save_delta >= config.train.save_every_seconds:
|
||||||
|
save_trainer(
|
||||||
|
trainer=trainer,
|
||||||
|
tracker=tracker,
|
||||||
|
is_best=False,
|
||||||
|
is_latest=True,
|
||||||
|
samples_seen=samples_seen,
|
||||||
|
epoch=epoch,
|
||||||
|
best_validation_loss=best_validation_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_timer.reset()
|
||||||
|
|
||||||
|
elif save_delta is None:
|
||||||
|
click.secho(
|
||||||
|
f"Error occured reading the save time on rank: {trainer.device}",
|
||||||
|
fg="yellow",
|
||||||
|
)
|
||||||
|
|
||||||
# evaluate on test data
|
# evaluate on test data
|
||||||
|
|
||||||
eval_model(
|
if trainer.accelerator.is_main_process:
|
||||||
|
click.secho(f"Starting Test", fg="red")
|
||||||
|
|
||||||
|
# save one last time as latest before beginning validation
|
||||||
|
|
||||||
|
save_trainer(
|
||||||
|
tracker=tracker,
|
||||||
|
trainer=trainer,
|
||||||
|
is_best=False,
|
||||||
|
is_latest=True,
|
||||||
|
samples_seen=samples_seen,
|
||||||
|
epoch=epoch,
|
||||||
|
best_validation_loss=best_validation_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
test_loss = eval_model(
|
||||||
trainer=trainer,
|
trainer=trainer,
|
||||||
dataloader=test_loader,
|
dataloader=test_loader,
|
||||||
text_conditioned=config.prior.condition_on_text_encodings,
|
text_conditioned=config.prior.condition_on_text_encodings,
|
||||||
loss_type=config.prior.loss_type,
|
split="test",
|
||||||
tracker_context="test",
|
|
||||||
tracker=tracker,
|
tracker=tracker,
|
||||||
|
use_ema=True,
|
||||||
|
report_cosine=False,
|
||||||
|
report_loss=True,
|
||||||
|
timesteps=config.train.eval_timesteps,
|
||||||
|
loss_type=config.prior.loss_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
report_cosine_sims(
|
if test_loss < best_validation_loss:
|
||||||
trainer,
|
best_validation_loss = test_loss
|
||||||
test_loader,
|
|
||||||
config.prior.condition_on_text_encodings,
|
# go save the model as best
|
||||||
tracker,
|
|
||||||
tracker_context="test",
|
save_trainer(
|
||||||
)
|
trainer=trainer,
|
||||||
|
tracker=tracker,
|
||||||
|
is_best=True,
|
||||||
|
is_latest=False,
|
||||||
|
samples_seen=samples_seen,
|
||||||
|
epoch=epoch,
|
||||||
|
best_validation_loss=test_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def initialize_training(config, accelerator=None):
|
def initialize_training(config_file, accelerator):
|
||||||
"""
|
"""
|
||||||
Parse the configuration file, and prepare everything necessary for training
|
Parse the configuration file, and prepare everything necessary for training
|
||||||
"""
|
"""
|
||||||
|
# load the configuration file
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
click.secho(f"Loading configuration from {config_file}", fg="green")
|
||||||
|
|
||||||
|
config = TrainDiffusionPriorConfig.from_json_path(config_file)
|
||||||
|
|
||||||
|
# seed
|
||||||
|
|
||||||
|
set_seed(config.train.random_seed)
|
||||||
|
|
||||||
# get a device
|
# get a device
|
||||||
|
|
||||||
if accelerator:
|
device = accelerator.device
|
||||||
device = accelerator.device
|
|
||||||
click.secho(f"Accelerating on: {device}", fg="yellow")
|
|
||||||
else:
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
|
|
||||||
device = "cuda:0"
|
|
||||||
else:
|
|
||||||
click.secho("No GPU detected...using cpu", fg="yellow")
|
|
||||||
device = "cpu"
|
|
||||||
|
|
||||||
# make the trainer (will automatically distribute if possible & configured)
|
# make the trainer (will automatically distribute if possible & configured)
|
||||||
|
|
||||||
trainer = make_model(config.prior, config.train, device, accelerator).to(device)
|
trainer: DiffusionPriorTrainer = make_model(
|
||||||
|
config.prior, config.train, device, accelerator
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
# create a tracker
|
||||||
|
|
||||||
|
tracker = create_tracker(
|
||||||
|
accelerator, config, config_file, dummy=accelerator.process_index != 0
|
||||||
|
)
|
||||||
|
|
||||||
# reload from chcekpoint
|
# reload from chcekpoint
|
||||||
|
|
||||||
if config.load.resume == True:
|
if tracker.can_recall:
|
||||||
click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
|
current_epoch, best_validation_loss, samples_seen = recall_trainer(
|
||||||
trainer.load(config.load.source)
|
tracker=tracker, trainer=trainer
|
||||||
|
)
|
||||||
|
|
||||||
|
# display best values
|
||||||
|
if trainer.accelerator.is_main_process:
|
||||||
|
click.secho(f"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}", fg="yellow")
|
||||||
|
|
||||||
|
# update config to reflect recalled values
|
||||||
|
config.train.num_samples_seen = samples_seen
|
||||||
|
config.train.current_epoch = current_epoch
|
||||||
|
config.train.best_validation_loss = best_validation_loss
|
||||||
|
|
||||||
# fetch and prepare data
|
# fetch and prepare data
|
||||||
|
|
||||||
if trainer.is_main_process():
|
if trainer.accelerator.is_main_process:
|
||||||
click.secho("Grabbing data from source", fg="blue", blink=True)
|
click.secho("Grabbing data...", fg="blue", blink=True)
|
||||||
|
|
||||||
|
trainer.accelerator.wait_for_everyone()
|
||||||
img_reader = get_reader(
|
img_reader = get_reader(
|
||||||
text_conditioned=trainer.text_conditioned,
|
text_conditioned=trainer.text_conditioned,
|
||||||
img_url=config.data.image_url,
|
img_url=config.data.image_url,
|
||||||
meta_url=config.data.meta_url,
|
meta_url=config.data.meta_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# calculate start point within epoch
|
||||||
|
|
||||||
|
trainer.accelerator.wait_for_everyone()
|
||||||
|
|
||||||
train_loader, eval_loader, test_loader = make_splits(
|
train_loader, eval_loader, test_loader = make_splits(
|
||||||
text_conditioned=trainer.text_conditioned,
|
text_conditioned=trainer.text_conditioned,
|
||||||
batch_size=config.data.batch_size,
|
batch_size=config.data.batch_size,
|
||||||
num_data_points=NUM_DATA_POINTS,
|
num_data_points=config.data.num_data_points,
|
||||||
train_split=config.data.splits.train,
|
train_split=config.data.splits.train,
|
||||||
eval_split=config.data.splits.val,
|
eval_split=config.data.splits.val,
|
||||||
image_reader=img_reader,
|
image_reader=img_reader,
|
||||||
rank=accelerator.state.process_index if exists(accelerator) else 0,
|
rank=accelerator.state.process_index,
|
||||||
world_size=accelerator.state.num_processes if exists(accelerator) else 1,
|
world_size=accelerator.state.num_processes,
|
||||||
start=START,
|
start=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# wait for everyone to load data before continuing
|
# update the start point to finish out the epoch on a resumed run
|
||||||
trainer.wait_for_everyone()
|
|
||||||
|
if tracker.can_recall:
|
||||||
|
samples_seen = config.train.num_samples_seen
|
||||||
|
length = (
|
||||||
|
config.data.num_data_points
|
||||||
|
if samples_seen <= img_reader.count
|
||||||
|
else img_reader.count
|
||||||
|
)
|
||||||
|
scaled_samples = length * config.train.current_epoch
|
||||||
|
start_point = (
|
||||||
|
scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
|
||||||
|
)
|
||||||
|
|
||||||
|
if trainer.accelerator.is_main_process:
|
||||||
|
click.secho(f"Resuming at sample: {start_point}", fg="yellow")
|
||||||
|
|
||||||
|
train_loader.dataset.set_start(start_point)
|
||||||
|
|
||||||
# start training
|
# start training
|
||||||
|
|
||||||
|
if trainer.accelerator.is_main_process:
|
||||||
|
click.secho(
|
||||||
|
f"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}",
|
||||||
|
fg="yellow",
|
||||||
|
)
|
||||||
|
|
||||||
train(
|
train(
|
||||||
trainer=trainer,
|
trainer=trainer,
|
||||||
|
tracker=tracker,
|
||||||
train_loader=train_loader,
|
train_loader=train_loader,
|
||||||
eval_loader=eval_loader,
|
eval_loader=eval_loader,
|
||||||
test_loader=test_loader,
|
test_loader=test_loader,
|
||||||
@@ -400,23 +757,13 @@ def initialize_training(config, accelerator=None):
|
|||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option("--hfa", default=True)
|
@click.option("--config_file", default="configs/train_prior_config.example.json")
|
||||||
@click.option("--config_path", default="configs/prior.json")
|
def main(config_file):
|
||||||
def main(hfa, config_path):
|
# start HFA
|
||||||
# start HFA if requested
|
accelerator = Accelerator()
|
||||||
if hfa:
|
|
||||||
accelerator = Accelerator()
|
|
||||||
else:
|
|
||||||
accelerator = None
|
|
||||||
|
|
||||||
# load the configuration file on main process
|
# setup training
|
||||||
if not exists(accelerator) or accelerator.is_main_process:
|
initialize_training(config_file, accelerator)
|
||||||
click.secho(f"Loading configuration from {config_path}", fg="green")
|
|
||||||
|
|
||||||
config = TrainDiffusionPriorConfig.from_json_path(config_path)
|
|
||||||
|
|
||||||
# send config to get processed
|
|
||||||
initialize_training(config, accelerator)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user