Compare commits

..

1 Commits

Author SHA1 Message Date
Phil Wang
6520d17215 fix ddim to use alpha_cumprod 2022-08-30 20:35:08 -07:00
9 changed files with 94 additions and 221 deletions

View File

@@ -634,12 +634,10 @@ Alternatively, you can also use <a href="https://github.com/mlfoundations/open_c
$ pip install open-clip-torch
```
Ex. using the <a href="https://laion.ai/blog/large-openclip/">SOTA Open Clip</a> model trained by <a href="https://github.com/rom1504">Romain</a>
```python
from dalle2_pytorch import OpenClipAdapter
clip = OpenClipAdapter('ViT-H/14')
clip = OpenClipAdapter()
```
Now you'll just have to worry about training the Prior and the Decoder!
@@ -1068,7 +1066,7 @@ dataloader = create_image_embedding_dataloader(
)
for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb["img"].shape) # torch.Size([32, 512])
print(emb.shape) # torch.Size([32, 512])
# Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually
@@ -1128,7 +1126,6 @@ For detailed information on training the diffusion prior, please refer to the [d
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
- [ ] add simple outpainting, text-guided 2x size the image for starters
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
## Citations
@@ -1298,14 +1295,4 @@ For detailed information on training the diffusion prior, please refer to the [d
}
```
```bibtex
@article{Salimans2022ProgressiveDF,
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
author = {Tim Salimans and Jonathan Ho},
journal = {ArXiv},
year = {2022},
volume = {abs/2202.00512}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,6 +1,6 @@
from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE

View File

@@ -12,8 +12,10 @@ from torch.utils.checkpoint import checkpoint
from torch import nn, einsum
import torchvision.transforms as T
from einops import rearrange, repeat, reduce, pack, unpack
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters import gaussian_blur2d
import kornia.augmentation as K
@@ -98,9 +100,6 @@ def eval_decorator(fn):
return out
return inner
def is_float_dtype(dtype):
return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
def is_list_str(x):
if not isinstance(x, (list, tuple)):
return False
@@ -315,10 +314,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
self.eos_id = 49407 # for handling 0 being also '!'
text_attention_final = self.find_layer('ln_final')
self.dim_latent_ = text_attention_final.weight.shape[0]
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
@@ -337,7 +333,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
@property
def dim_latent(self):
return self.dim_latent_
return 512
@property
def image_size(self):
@@ -358,7 +354,6 @@ class OpenAIClipAdapter(BaseClipAdapter):
is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
text_mask = text_mask & (text != 0)
assert not self.cleared
text_embed = self.clip.encode_text(text)
@@ -388,8 +383,6 @@ class OpenClipAdapter(BaseClipAdapter):
self.eos_id = 49407
text_attention_final = self.find_layer('ln_final')
self._dim_latent = text_attention_final.weight.shape[0]
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
@@ -409,14 +402,11 @@ class OpenClipAdapter(BaseClipAdapter):
@property
def dim_latent(self):
return self._dim_latent
return 512
@property
def image_size(self):
image_size = self.clip.visual.image_size
if isinstance(image_size, tuple):
return max(image_size)
return image_size
return self.clip.visual.image_size
@property
def image_channels(self):
@@ -433,7 +423,6 @@ class OpenClipAdapter(BaseClipAdapter):
is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
text_mask = text_mask & (text != 0)
assert not self.cleared
text_embed = self.clip.encode_text(text)
@@ -619,7 +608,7 @@ class NoiseScheduler(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def q_sample(self, x_start, t, noise = None):
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
@@ -627,12 +616,6 @@ class NoiseScheduler(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def calculate_v(self, x_start, t, noise = None):
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
)
def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
shape = x_from.shape
noise = default(noise, lambda: torch.randn_like(x_from))
@@ -644,12 +627,6 @@ class NoiseScheduler(nn.Module):
return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
def predict_start_from_v(self, x_t, t, v):
return (
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
@@ -667,23 +644,6 @@ class NoiseScheduler(nn.Module):
return loss
return loss * extract(self.p2_loss_weight, times, loss.shape)
# rearrange image to sequence
class RearrangeToSequence(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
x = rearrange(x, 'b c ... -> b ... c')
x, ps = pack([x], 'b * c')
x = self.fn(x)
x, = unpack(x, ps, 'b * c')
x = rearrange(x, 'b ... c -> b c ...')
return x
# diffusion prior
class LayerNorm(nn.Module):
@@ -882,7 +842,7 @@ class Attention(nn.Module):
# add null key / value for classifier free guidance in prior net
nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
@@ -1002,8 +962,6 @@ class DiffusionPriorNetwork(nn.Module):
Rearrange('b (n d) -> b n d', n = num_text_embeds)
)
self.continuous_embedded_time = not exists(num_timesteps)
self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds)
@@ -1112,7 +1070,7 @@ class DiffusionPriorNetwork(nn.Module):
null_text_embeds = self.null_text_embeds.to(text_embed.dtype)
text_embed = torch.where(
text_embeds = torch.where(
text_keep_mask,
text_embed,
null_text_embeds
@@ -1131,15 +1089,12 @@ class DiffusionPriorNetwork(nn.Module):
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if self.continuous_embedded_time:
diffusion_timesteps = diffusion_timesteps.type(dtype)
time_embed = self.to_time_embeds(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
if self.self_cond:
learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
learned_queries = torch.cat((image_embed, self_cond), dim = -2)
tokens = torch.cat((
text_encodings,
@@ -1175,7 +1130,6 @@ class DiffusionPrior(nn.Module):
image_cond_drop_prob = None,
loss_type = "l2",
predict_x_start = True,
predict_v = False,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
@@ -1227,7 +1181,6 @@ class DiffusionPrior(nn.Module):
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
self.predict_x_start = predict_x_start
self.predict_v = predict_v # takes precedence over predict_x_start
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
@@ -1257,9 +1210,7 @@ class DiffusionPrior(nn.Module):
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
if self.predict_v:
x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
elif self.predict_x_start:
if self.predict_x_start:
x_start = pred
else:
x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1308,7 +1259,7 @@ class DiffusionPrior(nn.Module):
def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
@@ -1330,16 +1281,12 @@ class DiffusionPrior(nn.Module):
pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
# derive x0
if self.predict_v:
x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
elif self.predict_x_start:
if self.predict_x_start:
x_start = pred
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
else:
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
# clip x0 before maybe predicting noise
pred_noise = pred
if not self.predict_x_start:
x_start.clamp_(-1., 1.)
@@ -1347,14 +1294,6 @@ class DiffusionPrior(nn.Module):
if self.predict_x_start and self.sampling_clamp_l2norm:
x_start = self.l2norm_clamp_embed(x_start)
# predict noise
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
if time_next < 0:
image_embed = x_start
continue
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
@@ -1404,12 +1343,7 @@ class DiffusionPrior(nn.Module):
if self.predict_x_start and self.training_clamp_l2norm:
pred = self.l2norm_clamp_embed(pred)
if self.predict_v:
target = self.noise_scheduler.calculate_v(image_embed, times, noise)
elif self.predict_x_start:
target = image_embed
else:
target = noise
target = noise if not self.predict_x_start else image_embed
loss = self.noise_scheduler.loss_fn(pred, target)
return loss
@@ -1479,7 +1413,7 @@ class DiffusionPrior(nn.Module):
**kwargs
):
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
if exists(image):
@@ -1585,8 +1519,6 @@ class SinusoidalPosEmb(nn.Module):
def forward(self, x):
dtype, device = x.dtype, x.device
assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
@@ -1644,10 +1576,14 @@ class ResnetBlock(nn.Module):
self.cross_attn = None
if exists(cond_dim):
self.cross_attn = CrossAttention(
dim = dim_out,
context_dim = cond_dim,
cosine_sim = cosine_sim_cross_attn
self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim_out,
context_dim = cond_dim,
cosine_sim = cosine_sim_cross_attn
)
)
self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
@@ -1666,15 +1602,8 @@ class ResnetBlock(nn.Module):
if exists(self.cross_attn):
assert exists(cond)
h = rearrange(h, 'b c ... -> b ... c')
h, ps = pack([h], 'b * c')
h = self.cross_attn(h, context = cond) + h
h, = unpack(h, ps, 'b * c')
h = rearrange(h, 'b ... c -> b c ...')
h = self.block2(h)
return h + self.res_conv(x)
@@ -1720,11 +1649,11 @@ class CrossAttention(nn.Module):
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
# add null key / value for classifier free guidance in prior net
nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
@@ -1777,7 +1706,7 @@ class LinearAttention(nn.Module):
fmap = self.norm(fmap)
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
q = q.softmax(dim = -1)
k = k.softmax(dim = -2)
@@ -2011,7 +1940,7 @@ class Unet(nn.Module):
self_attn = cast_tuple(self_attn, num_stages)
create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim, **attn_kwargs)))
create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
# resnet block klass
@@ -2488,7 +2417,6 @@ class Decoder(nn.Module):
loss_type = 'l2',
beta_schedule = None,
predict_x_start = False,
predict_v = False,
predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
@@ -2511,7 +2439,7 @@ class Decoder(nn.Module):
dynamic_thres_percentile = 0.95,
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1,
ddim_sampling_eta = 0. # can be set to 0. for deterministic sampling afaict
ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
):
super().__init__()
@@ -2661,10 +2589,6 @@ class Decoder(nn.Module):
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
# predict v
self.predict_v = cast_tuple(predict_v, len(unets))
# input image range
self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
@@ -2745,16 +2669,11 @@ class Decoder(nn.Module):
if exists(unet_number):
unet = self.get_unet(unet_number)
# devices
cuda, cpu = torch.device('cuda'), torch.device('cpu')
self.cuda()
devices = [module_device(unet) for unet in self.unets]
self.unets.to(cpu)
unet.to(cuda)
self.unets.cpu()
unet.cuda()
yield
@@ -2781,16 +2700,14 @@ class Decoder(nn.Module):
x = x.clamp(-s, s) / s
return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
if predict_v:
x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
elif predict_x_start:
if predict_x_start:
x_start = pred
else:
x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -2817,9 +2734,9 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
@@ -2834,7 +2751,6 @@ class Decoder(nn.Module):
image_embed,
noise_scheduler,
predict_x_start = False,
predict_v = False,
learned_variance = False,
clip_denoised = True,
lowres_cond_img = None,
@@ -2893,7 +2809,6 @@ class Decoder(nn.Module):
lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level,
predict_x_start = predict_x_start,
predict_v = predict_v,
noise_scheduler = noise_scheduler,
learned_variance = learned_variance,
clip_denoised = clip_denoised
@@ -2919,7 +2834,6 @@ class Decoder(nn.Module):
timesteps,
eta = 1.,
predict_x_start = False,
predict_v = False,
learned_variance = False,
clip_denoised = True,
lowres_cond_img = None,
@@ -2979,24 +2893,16 @@ class Decoder(nn.Module):
pred, _ = self.parse_unet_output(learned_variance, unet_output)
# predict x0
if predict_v:
x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
elif predict_x_start:
if predict_x_start:
x_start = pred
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
# maybe clip x0
pred_noise = pred
if clip_denoised:
x_start = self.dynamic_threshold(x_start)
# predict noise
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if not is_last_timestep else 0.
@@ -3029,7 +2935,7 @@ class Decoder(nn.Module):
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
@@ -3074,12 +2980,7 @@ class Decoder(nn.Module):
pred, _ = self.parse_unet_output(learned_variance, unet_output)
if predict_v:
target = noise_scheduler.calculate_v(x_start, times, noise)
elif predict_x_start:
target = x_start
else:
target = noise
target = noise if not predict_x_start else x_start
loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b (...)', 'mean')
@@ -3137,8 +3038,7 @@ class Decoder(nn.Module):
distributed = False,
inpaint_image = None,
inpaint_mask = None,
inpaint_resample_times = 5,
one_unet_in_gpu_at_time = True
inpaint_resample_times = 5
):
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
@@ -3161,17 +3061,16 @@ class Decoder(nn.Module):
assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
img = resize_image_to(image, prev_unet_output_size, nearest = True)
is_cuda = next(self.parameters()).is_cuda
num_unets = self.num_unets
cond_scale = cast_tuple(cond_scale, num_unets)
for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
if unet_number < start_at_unet_number:
continue # It's the easiest way to do it
context = self.one_unet_in_gpu(unet = unet) if is_cuda and one_unet_in_gpu_at_time else null_context()
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
with context:
# prepare low resolution conditioning for upsamplers
@@ -3203,7 +3102,6 @@ class Decoder(nn.Module):
text_encodings = text_encodings,
cond_scale = unet_cond_scale,
predict_x_start = predict_x_start,
predict_v = predict_v,
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img,
@@ -3243,12 +3141,11 @@ class Decoder(nn.Module):
lowres_conditioner = self.lowres_conds[unet_index]
target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
predict_v = self.predict_v[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
learned_variance = self.learned_variance[unet_index]
b, c, h, w, device, = *image.shape, image.device
assert image.shape[1] == self.channels
check_shape(image, 'b c h w', c = self.channels)
assert h >= target_image_size and w >= target_image_size
times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
@@ -3282,7 +3179,7 @@ class Decoder(nn.Module):
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
if not return_lowres_cond_image:
return losses

View File

@@ -1,16 +1,14 @@
import json
from torchvision import transforms as T
from pydantic import BaseModel, validator, model_validator
from pydantic import BaseModel, validator, root_validator
from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
from x_clip import CLIP as XCLIP
from open_clip import list_pretrained
from coca_pytorch import CoCa
from dalle2_pytorch.dalle2_pytorch import (
CoCaAdapter,
OpenAIClipAdapter,
OpenClipAdapter,
Unet,
Decoder,
DiffusionPrior,
@@ -38,12 +36,12 @@ class TrainSplitConfig(BaseModel):
val: float = 0.15
test: float = 0.1
@model_validator(mode = 'after')
def validate_all(self, m):
actual_sum = sum([*dict(self).values()])
@root_validator
def validate_all(cls, fields):
actual_sum = sum([*fields.values()])
if actual_sum != 1.:
raise ValueError(f'{dict(self).keys()} must sum to 1.0. Found: {actual_sum}')
return self
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
return fields
class TrackerLogConfig(BaseModel):
log_type: str = 'console'
@@ -59,7 +57,6 @@ class TrackerLogConfig(BaseModel):
kwargs = self.dict()
return create_logger(self.log_type, data_path, **kwargs)
class TrackerLoadConfig(BaseModel):
load_from: Optional[str] = None
only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming
@@ -90,7 +87,7 @@ class TrackerConfig(BaseModel):
data_path: str = '.tracker_data'
overwrite_data_path: bool = False
log: TrackerLogConfig
load: Optional[TrackerLoadConfig] = None
load: Optional[TrackerLoadConfig]
save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
@@ -115,15 +112,11 @@ class TrackerConfig(BaseModel):
class AdapterConfig(BaseModel):
make: str = "openai"
model: str = "ViT-L/14"
base_model_kwargs: Optional[Dict[str, Any]] = None
base_model_kwargs: Dict[str, Any] = None
def create(self):
if self.make == "openai":
return OpenAIClipAdapter(self.model)
elif self.make == "open_clip":
pretrained = dict(list_pretrained())
checkpoint = pretrained[self.model]
return OpenClipAdapter(name=self.model, pretrained=checkpoint)
elif self.make == "x-clip":
return XClipAdapter(XCLIP(**self.base_model_kwargs))
elif self.make == "coca":
@@ -134,8 +127,8 @@ class AdapterConfig(BaseModel):
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
max_text_len: Optional[int] = None
num_timesteps: Optional[int] = None
max_text_len: int = None
num_timesteps: int = None
num_time_embeds: int = 1
num_image_embeds: int = 1
num_text_embeds: int = 1
@@ -158,7 +151,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
return DiffusionPriorNetwork(**kwargs)
class DiffusionPriorConfig(BaseModel):
clip: Optional[AdapterConfig] = None
clip: AdapterConfig = None
net: DiffusionPriorNetworkConfig
image_embed_dim: int
image_size: int
@@ -195,7 +188,7 @@ class DiffusionPriorTrainConfig(BaseModel):
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
warmup_steps: Optional[int] = None # number of warmup steps
warmup_steps: int = None # number of warmup steps
save_every_seconds: int = 3600 # how often to save
eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
best_validation_loss: float = 1e9 # the current best valudation loss observed
@@ -228,10 +221,10 @@ class TrainDiffusionPriorConfig(BaseModel):
class UnetConfig(BaseModel):
dim: int
dim_mults: ListOrTuple[int]
image_embed_dim: Optional[int] = None
text_embed_dim: Optional[int] = None
cond_on_text_encodings: Optional[bool] = None
cond_dim: Optional[int] = None
image_embed_dim: int = None
text_embed_dim: int = None
cond_on_text_encodings: bool = None
cond_dim: int = None
channels: int = 3
self_attn: ListOrTuple[int]
attn_dim_head: int = 32
@@ -243,14 +236,14 @@ class UnetConfig(BaseModel):
class DecoderConfig(BaseModel):
unets: ListOrTuple[UnetConfig]
image_size: Optional[int] = None
image_size: int = None
image_sizes: ListOrTuple[int] = None
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
loss_type: str = 'l2'
beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine
beta_schedule: ListOrTuple[str] = None # None means all cosine
learned_variance: SingularOrIterable[bool] = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
@@ -278,9 +271,9 @@ class DecoderConfig(BaseModel):
extra = "allow"
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
img_embeddings_url: Optional[str] = None # path to .npy files with embeddings
text_embeddings_url: Optional[str] = None # path to .npy files with embeddings
webdataset_base_url: str # path to a webdataset with jpg images
img_embeddings_url: Optional[str] # path to .npy files with embeddings
text_embeddings_url: Optional[str] # path to .npy files with embeddings
num_workers: int = 4
batch_size: int = 64
start_shard: int = 0
@@ -314,26 +307,25 @@ class DecoderTrainConfig(BaseModel):
wd: SingularOrIterable[float] = 0.01
warmup_steps: Optional[SingularOrIterable[int]] = None
find_unused_parameters: bool = True
static_graph: bool = True
max_grad_norm: SingularOrIterable[float] = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
cond_scale: Union[float, List[float]] = 1.0
device: str = 'cuda:0'
epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: Optional[int] = None # Same as above but for validation.
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
save_immediately: bool = False
use_ema: bool = True
ema_beta: float = 0.999
amp: bool = False
unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets
unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
FID: Optional[Dict[str, Any]] = None
IS: Optional[Dict[str, Any]] = None
KID: Optional[Dict[str, Any]] = None
LPIPS: Optional[Dict[str, Any]] = None
FID: Dict[str, Any] = None
IS: Dict[str, Any] = None
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
class TrainDecoderConfig(BaseModel):
decoder: DecoderConfig
@@ -347,14 +339,11 @@ class TrainDecoderConfig(BaseModel):
def from_json_path(cls, json_path):
with open(json_path) as f:
config = json.load(f)
print(config)
return cls(**config)
@model_validator(mode = 'after')
def check_has_embeddings(self, m):
@root_validator
def check_has_embeddings(cls, values):
# Makes sure that enough information is provided to get the embeddings specified for training
values = dict(self)
data_config, decoder_config = values.get('data'), values.get('decoder')
if not exists(data_config) or not exists(decoder_config):
@@ -379,4 +368,4 @@ class TrainDecoderConfig(BaseModel):
if text_emb_url:
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
return m
return values

View File

@@ -236,7 +236,7 @@ class DiffusionPriorTrainer(nn.Module):
)
if exists(cosine_decay_max_steps):
self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
else:
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)

View File

@@ -1 +1 @@
__version__ = '1.15.4'
__version__ = '1.10.2'

View File

@@ -11,7 +11,8 @@ import torch.nn.functional as F
from torch.autograd import grad as torch_grad
import torchvision
from einops import rearrange, reduce, repeat, pack, unpack
from einops import rearrange, reduce, repeat
from einops_exts import rearrange_many
from einops.layers.torch import Rearrange
# constants
@@ -407,7 +408,7 @@ class Attention(nn.Module):
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)

View File

@@ -26,17 +26,17 @@ setup(
install_requires=[
'accelerate',
'click',
'open-clip-torch>=2.0.0,<3.0.0',
'clip-anytorch>=2.5.2',
'clip-anytorch>=2.4.0',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',
'einops>=0.7.0',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'numpy',
'packaging',
'pillow',
'pydantic>=2',
'pydantic',
'pytorch-warmup',
'resize-right>=0.0.2',
'rotary-embedding-torch',

View File

@@ -156,7 +156,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
if text_embeddings[0] is None:
# Generate text embeddings from text
assert clip is not None, "clip is None, but text_embeddings is None"
tokenized_texts = tokenize(txts, truncate=True).to(device=device)
tokenized_texts = tokenize(txts, truncate=True)
text_embed, text_encodings = clip.embed_text(tokenized_texts)
sample_params["text_encodings"] = text_encodings
else:
@@ -229,8 +229,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=Non
metrics["KID_std"] = kid_std.item()
if exists(LPIPS):
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images)
@@ -480,7 +480,7 @@ def train(
else:
# Then we need to pass the text instead
assert clip is not None
tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
text_embed, text_encodings = clip.embed_text(tokenized_texts)
forward_params['text_encodings'] = text_encodings
@@ -556,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
torch.manual_seed(config.seed)
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
@@ -577,7 +577,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
shards_per_process = len(all_shards) // world_size
assert shards_per_process > 0, "Not enough shards to split evenly"
my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
dataloaders = create_dataloaders (
available_shards=my_shards,
img_preproc = config.data.img_preproc,