Compare commits


11 Commits

Author  SHA1  Message  Date
Phil Wang  cc58f75474  bump to newer package of clip-anytorch that allows for text encodings < maximum context length  2023-03-04 09:37:25 -08:00
Phil Wang  3b2cf7b0bc  fix for self conditioning in diffusion prior network https://github.com/lucidrains/DALLE2-pytorch/issues/273  2023-02-11 17:18:40 -08:00
Phil Wang  984d62a373  default ddim sampling eta to 0  2022-12-23 13:23:09 -08:00
Phil Wang  683dd98b96  extra insurance in case eos id is not there  2022-12-15 10:54:21 -08:00
Phil Wang  067ac323da  address https://github.com/lucidrains/DALLE2-pytorch/issues/266  2022-11-23 08:41:25 -08:00
zion  91c8d1ca13  bug fix cosine annealing optimizer in prior trainer (#262)  2022-11-11 12:15:13 -08:00
zion  08238a7200  depend on open-clip-torch (#261)  2022-11-07 16:19:08 -08:00
    fix the previous commit which assumes open_clip is installed
zion  7166ad6711  add open clip to train_config (#260)  2022-11-07 15:44:36 -08:00
    add the ability to use open_clip in the train configs (useful for the new SOTA h/14 model)
Phil Wang  fbba0f9aaf  bring in prediction of v objective, combining the findings from progressive distillation paper and imagen-video to the eventual extension of dalle2 to make-a-video  2022-10-28 18:21:07 -07:00
Romain Beaumont  9f37705d87  Add static graph param (#226)  2022-10-25 19:31:29 +02:00
    * Add static graph param
    * use static graph param
Phil Wang  c3df46e374  fix openclipadapter to be able to use latest open sourced sota model  2022-10-23 15:12:09 -07:00
8 changed files with 89 additions and 25 deletions

README.md

@@ -1298,4 +1298,14 @@ For detailed information on training the diffusion prior, please refer to the [d
 }
 ```
 
+```bibtex
+@article{Salimans2022ProgressiveDF,
+    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
+    author  = {Tim Salimans and Jonathan Ho},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2202.00512}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

dalle2_pytorch/__init__.py

@@ -1,6 +1,6 @@
 from dalle2_pytorch.version import __version__
 from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
-from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
+from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter
 from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
 from dalle2_pytorch.vqgan_vae import VQGanVAE

dalle2_pytorch/dalle2_pytorch.py

@@ -360,6 +360,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         is_eos_id = (text == self.eos_id)
         text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
         text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+        text_mask = text_mask & (text != 0)
         assert not self.cleared
 
         text_embed = self.clip.encode_text(text)
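The added `text_mask = text_mask & (text != 0)` line (applied to both CLIP adapters in this diff) excludes zero padding tokens from the text mask, so text encodings shorter than the maximum context length no longer cover padding positions. A minimal standalone sketch of the mask construction, not the library code itself, assuming token id 0 is padding and 49407 is CLIP's end-of-text id:

```python
# Standalone sketch of the mask logic above (illustrative token ids, not library code).
import torch
import torch.nn.functional as F

eos_id = 49407
text = torch.tensor([[49406, 320, 1125, 49407, 0, 0]])  # example: BOS, two word tokens, EOS, padding

is_eos_id = (text == eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0           # True strictly before EOS
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)   # shift right so EOS itself is kept
text_mask = text_mask & (text != 0)                                 # new: drop zero padding tokens

print(text_mask)  # tensor([[ True,  True,  True,  True, False, False]])
```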
@@ -389,6 +390,8 @@ class OpenClipAdapter(BaseClipAdapter):
         self.eos_id = 49407
 
         text_attention_final = self.find_layer('ln_final')
+        self._dim_latent = text_attention_final.weight.shape[0]
+
         self.handle = text_attention_final.register_forward_hook(self._hook)
         self.clip_normalize = preprocess.transforms[-1]
         self.cleared = False
@@ -408,7 +411,7 @@ class OpenClipAdapter(BaseClipAdapter):
     @property
     def dim_latent(self):
-        return 512
+        return self._dim_latent
 
     @property
     def image_size(self):
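Previously `dim_latent` was hardcoded to 512, which breaks larger open_clip models such as ViT-H/14, whose text tower is 1024-wide; the change reads the width off `ln_final` instead. A small illustration with a stand-in module (not open_clip itself):

```python
# Why the hardcoded 512 fails for larger models: ln_final's weight has shape (embed_dim,),
# so its size gives the text latent dimension directly.
import torch.nn as nn

class TinyTextTower(nn.Module):
    """Stand-in for an open_clip text tower with a configurable width."""
    def __init__(self, width):
        super().__init__()
        self.ln_final = nn.LayerNorm(width)

vit_b = TinyTextTower(width = 512)
vit_h = TinyTextTower(width = 1024)   # e.g. the ViT-H/14 checkpoints

print(vit_b.ln_final.weight.shape[0])  # 512
print(vit_h.ln_final.weight.shape[0])  # 1024 - a hardcoded 512 here would be wrong
```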
@@ -432,6 +435,7 @@ class OpenClipAdapter(BaseClipAdapter):
         is_eos_id = (text == self.eos_id)
         text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
         text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+        text_mask = text_mask & (text != 0)
         assert not self.cleared
 
         text_embed = self.clip.encode_text(text)
@@ -617,7 +621,7 @@ class NoiseScheduler(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def q_sample(self, x_start, t, noise=None):
+    def q_sample(self, x_start, t, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         return (
@@ -625,6 +629,12 @@ class NoiseScheduler(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )
 
+    def calculate_v(self, x_start, t, noise = None):
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
+        )
+
     def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
         shape = x_from.shape
         noise = default(noise, lambda: torch.randn_like(x_from))
@@ -636,6 +646,12 @@ class NoiseScheduler(nn.Module):
         return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
 
+    def predict_start_from_v(self, x_t, t, v):
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
     def predict_start_from_noise(self, x_t, t, noise):
         return (
             extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
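`calculate_v` and `predict_start_from_v` implement the v-parameterization from the progressive distillation paper cited in the README change above: v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x0, and x0 is recovered as sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v. A self-contained numerical check of those identities (illustrative values, not the library code):

```python
# Numerical check of the v-objective identities used by calculate_v / predict_start_from_v.
import torch

alpha_bar_t = torch.tensor(0.7)            # example cumulative alpha at some timestep t
a, b = alpha_bar_t.sqrt(), (1 - alpha_bar_t).sqrt()

x0    = torch.randn(4)                     # clean signal (image embedding or image)
noise = torch.randn(4)
x_t   = a * x0 + b * noise                 # forward diffusion, as in q_sample

v = a * noise - b * x0                     # calculate_v
x0_recovered = a * x_t - b * v             # predict_start_from_v

assert torch.allclose(x0, x0_recovered, atol = 1e-6)  # holds since a**2 + b**2 == 1
```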
@@ -1108,7 +1124,7 @@ class DiffusionPriorNetwork(nn.Module):
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
 
         if self.self_cond:
-            learned_queries = torch.cat((image_embed, self_cond), dim = -2)
+            learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
 
         tokens = torch.cat((
             text_encodings,
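The self-conditioning fix keeps the learned query token and prepends the previous prediction to it, instead of overwriting the queries with `(image_embed, self_cond)`. A shape-only sketch with hypothetical tensors, not the full `DiffusionPriorNetwork` forward pass:

```python
# Shape-level sketch of the corrected self-conditioning concat.
import torch
from einops import repeat

batch, dim = 2, 512
learned_query = torch.randn(dim)             # a single learned token
self_cond = torch.zeros(batch, 1, dim)       # previous x0 prediction (zeros on the first pass)

learned_queries = repeat(learned_query, 'd -> b 1 d', b = batch)

# fixed: keep the learned query and prepend the self-conditioning token
learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
print(learned_queries.shape)  # torch.Size([2, 2, 512]) - both tokens survive
```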
@@ -1144,6 +1160,7 @@ class DiffusionPrior(nn.Module):
         image_cond_drop_prob = None,
         loss_type = "l2",
         predict_x_start = True,
+        predict_v = False,
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
         sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
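With the new `predict_v` flag the prior can be trained on the v objective instead of predicting x0 or noise. A usage sketch following the repository README's usual prior setup; the hyperparameters below are illustrative:

```python
# Hedged usage sketch: enabling the v objective on the diffusion prior.
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, OpenAIClipAdapter

clip = OpenAIClipAdapter()

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2,
    predict_v = True        # new flag from this change; takes precedence over predict_x_start
)

# training then proceeds as before, e.g. loss = diffusion_prior(text, images); loss.backward()
```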
@@ -1195,6 +1212,7 @@ class DiffusionPrior(nn.Module):
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
 
         self.predict_x_start = predict_x_start
+        self.predict_v = predict_v # takes precedence over predict_x_start
 
         # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
@@ -1224,7 +1242,9 @@ class DiffusionPrior(nn.Module):
         pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
 
-        if self.predict_x_start:
+        if self.predict_v:
+            x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+        elif self.predict_x_start:
             x_start = pred
         else:
             x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1297,10 +1317,12 @@ class DiffusionPrior(nn.Module):
             # derive x0
 
-            if self.predict_x_start:
+            if self.predict_v:
+                x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
+            elif self.predict_x_start:
                 x_start = pred
             else:
-                x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise)
+                x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
 
             # clip x0 before maybe predicting noise
@@ -1312,7 +1334,7 @@ class DiffusionPrior(nn.Module):
             # predict noise
 
-            if self.predict_x_start:
+            if self.predict_x_start or self.predict_v:
                 pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
             else:
                 pred_noise = pred
@@ -1370,7 +1392,12 @@ class DiffusionPrior(nn.Module):
         if self.predict_x_start and self.training_clamp_l2norm:
             pred = self.l2norm_clamp_embed(pred)
 
-        target = noise if not self.predict_x_start else image_embed
+        if self.predict_v:
+            target = self.noise_scheduler.calculate_v(image_embed, times, noise)
+        elif self.predict_x_start:
+            target = image_embed
+        else:
+            target = noise
 
         loss = self.noise_scheduler.loss_fn(pred, target)
         return loss
@@ -2446,6 +2473,7 @@ class Decoder(nn.Module):
         loss_type = 'l2',
         beta_schedule = None,
         predict_x_start = False,
+        predict_v = False,
         predict_x_start_for_latent_diffusion = False,
         image_sizes = None, # for cascading ddpm, image size at each stage
         random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
@@ -2468,7 +2496,7 @@ class Decoder(nn.Module):
         dynamic_thres_percentile = 0.95,
         p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
         p2_loss_weight_k = 1,
-        ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
+        ddim_sampling_eta = 0. # can be set to 0. for deterministic sampling afaict
     ):
         super().__init__()
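`ddim_sampling_eta` now defaults to 0, making decoder DDIM sampling deterministic by default. In the usual DDIM formulation, eta scales the per-step noise sigma; a small sketch assuming that standard notation (not the library's exact code):

```python
# Sketch of eta's role in DDIM (standard notation from Song et al. 2020, assumed here).
import torch

def ddim_sigma(alpha_cumprod, alpha_cumprod_prev, eta):
    return eta * ((1 - alpha_cumprod_prev) / (1 - alpha_cumprod)).sqrt() \
               * (1 - alpha_cumprod / alpha_cumprod_prev).sqrt()

alpha, alpha_prev = torch.tensor(0.5), torch.tensor(0.7)
print(ddim_sigma(alpha, alpha_prev, eta = 1.))  # > 0: noise injected at each step
print(ddim_sigma(alpha, alpha_prev, eta = 0.))  # 0: fully deterministic sampling
```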
@@ -2618,6 +2646,10 @@ class Decoder(nn.Module):
         self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
 
+        # predict v
+
+        self.predict_v = cast_tuple(predict_v, len(unets))
+
         # input image range
 
         self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
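Like `predict_x_start`, the new `predict_v` setting is broadcast to one value per unet, so it can be set globally or per stage of the cascade. A sketch using a local stand-in for the library's `cast_tuple` helper:

```python
# Local stand-in illustrating the per-unet broadcast of predict_v (not the library helper).
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else (val,) * length

num_unets = 2
print(cast_tuple(True, num_unets))           # (True, True)  - one setting for all unets
print(cast_tuple((True, False), num_unets))  # (True, False) - e.g. v objective only on the base unet
```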
@@ -2729,14 +2761,16 @@ class Decoder(nn.Module):
             x = x.clamp(-s, s) / s
         return x
 
-    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
+    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
         model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
 
         pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
 
-        if predict_x_start:
+        if predict_v:
+            x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+        elif predict_x_start:
             x_start = pred
         else:
             x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -2763,9 +2797,9 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance, x_start
 
     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
+    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
@@ -2780,6 +2814,7 @@ class Decoder(nn.Module):
         image_embed,
         noise_scheduler,
         predict_x_start = False,
+        predict_v = False,
         learned_variance = False,
         clip_denoised = True,
         lowres_cond_img = None,
@@ -2838,6 +2873,7 @@ class Decoder(nn.Module):
                 lowres_cond_img = lowres_cond_img,
                 lowres_noise_level = lowres_noise_level,
                 predict_x_start = predict_x_start,
+                predict_v = predict_v,
                 noise_scheduler = noise_scheduler,
                 learned_variance = learned_variance,
                 clip_denoised = clip_denoised
@@ -2863,6 +2899,7 @@ class Decoder(nn.Module):
         timesteps,
         eta = 1.,
         predict_x_start = False,
+        predict_v = False,
         learned_variance = False,
         clip_denoised = True,
         lowres_cond_img = None,
@@ -2924,7 +2961,9 @@ class Decoder(nn.Module):
             # predict x0
 
-            if predict_x_start:
+            if predict_v:
+                x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
+            elif predict_x_start:
                 x_start = pred
             else:
                 x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
@@ -2936,8 +2975,8 @@ class Decoder(nn.Module):
             # predict noise
 
-            if predict_x_start:
-                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
+            if predict_x_start or predict_v:
+                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
             else:
                 pred_noise = pred
@@ -2973,7 +3012,7 @@ class Decoder(nn.Module):
         return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
 
-    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
+    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         # normalize to [-1, 1]
@@ -3018,7 +3057,12 @@ class Decoder(nn.Module):
         pred, _ = self.parse_unet_output(learned_variance, unet_output)
 
-        target = noise if not predict_x_start else x_start
+        if predict_v:
+            target = noise_scheduler.calculate_v(x_start, times, noise)
+        elif predict_x_start:
+            target = x_start
+        else:
+            target = noise
 
         loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
         loss = reduce(loss, 'b ... -> b (...)', 'mean')
@@ -3104,7 +3148,7 @@ class Decoder(nn.Module):
         num_unets = self.num_unets
         cond_scale = cast_tuple(cond_scale, num_unets)
 
-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
             if unet_number < start_at_unet_number:
                 continue # It's the easiest way to do it
@@ -3140,6 +3184,7 @@ class Decoder(nn.Module):
                     text_encodings = text_encodings,
                     cond_scale = unet_cond_scale,
                     predict_x_start = predict_x_start,
+                    predict_v = predict_v,
                     learned_variance = learned_variance,
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,
@@ -3179,6 +3224,7 @@ class Decoder(nn.Module):
         lowres_conditioner = self.lowres_conds[unet_index]
         target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
+        predict_v = self.predict_v[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
         learned_variance = self.learned_variance[unet_index]
         b, c, h, w, device, = *image.shape, image.device
@@ -3217,7 +3263,7 @@ class Decoder(nn.Module):
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
 
-        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
+        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
 
         if not return_lowres_cond_image:
             return losses

dalle2_pytorch/train_configs.py

@@ -4,11 +4,13 @@ from pydantic import BaseModel, validator, root_validator
 from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
 
 from x_clip import CLIP as XCLIP
+from open_clip import list_pretrained
 from coca_pytorch import CoCa
 
 from dalle2_pytorch.dalle2_pytorch import (
     CoCaAdapter,
     OpenAIClipAdapter,
+    OpenClipAdapter,
     Unet,
     Decoder,
     DiffusionPrior,
@@ -117,6 +119,10 @@ class AdapterConfig(BaseModel):
     def create(self):
         if self.make == "openai":
             return OpenAIClipAdapter(self.model)
+        elif self.make == "open_clip":
+            pretrained = dict(list_pretrained())
+            checkpoint = pretrained[self.model]
+            return OpenClipAdapter(name=self.model, pretrained=checkpoint)
         elif self.make == "x-clip":
             return XClipAdapter(XCLIP(**self.base_model_kwargs))
         elif self.make == "coca":
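The new `open_clip` branch looks the configured model name up in open_clip's pretrained registry and instantiates `OpenClipAdapter` with the newest matching checkpoint tag. A sketch of that resolution, assuming a ViT-H-14 entry exists in the installed open_clip:

```python
# How the "open_clip" make resolves a checkpoint (mirrors the branch above).
from open_clip import list_pretrained

pretrained = dict(list_pretrained())  # {model name: newest pretrained tag listed}
print(pretrained.get('ViT-H-14'))     # newest tag for ViT-H-14, e.g. 'laion2b_s32b_b79k'

# equivalent adapter section of a training config (field names from AdapterConfig):
adapter_config = {"make": "open_clip", "model": "ViT-H-14"}
```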
@@ -307,6 +313,7 @@ class DecoderTrainConfig(BaseModel):
     wd: SingularOrIterable[float] = 0.01
     warmup_steps: Optional[SingularOrIterable[int]] = None
     find_unused_parameters: bool = True
+    static_graph: bool = True
     max_grad_norm: SingularOrIterable[float] = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset

dalle2_pytorch/trainer.py

@@ -236,7 +236,7 @@ class DiffusionPriorTrainer(nn.Module):
             )
 
             if exists(cosine_decay_max_steps):
-                self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
+                self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
             else:
                 self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
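The fix passes `self.optimizer` (the optimizer the trainer actually steps) to `CosineAnnealingLR`, matching the `LambdaLR` branch, rather than a bare `optimizer` reference. A minimal standalone sketch of the wiring this restores, with illustrative values:

```python
# The scheduler must wrap the same optimizer instance that gets stepped.
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = AdamW(params, lr = 1.1e-4)

cosine_decay_max_steps = 100
scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)

for _ in range(10):
    optimizer.step()
    scheduler.step()

print(optimizer.param_groups[0]['lr'])  # decayed by the schedule, since both share one optimizer
```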

dalle2_pytorch/version.py

@@ -1 +1 @@
-__version__ = '1.10.8'
+__version__ = '1.12.2'

setup.py

@@ -26,7 +26,8 @@ setup(
   install_requires=[
     'accelerate',
     'click',
-    'clip-anytorch>=2.4.0',
+    'open-clip-torch>=2.0.0,<3.0.0',
+    'clip-anytorch>=2.5.2',
     'coca-pytorch>=0.0.5',
     'ema-pytorch>=0.0.7',
     'einops>=0.4',

train_decoder.py

@@ -556,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     torch.manual_seed(config.seed)
 
     # Set up accelerator for configurable distributed training
-    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
+    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
     init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
     accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
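`static_graph`, read from the new `DecoderTrainConfig` field, is forwarded to torch DDP through accelerate's kwargs handler. A minimal standalone sketch with literal values in place of the `config.train.*` fields:

```python
# Standalone version of the DDP wiring above, with literal values instead of config fields.
from datetime import timedelta
from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs

ddp_kwargs = DistributedDataParallelKwargs(
    find_unused_parameters = True,   # config.train.find_unused_parameters
    static_graph = True              # config.train.static_graph (new field, defaults to True)
)
init_kwargs = InitProcessGroupKwargs(timeout = timedelta(seconds = 60 * 60))
accelerator = Accelerator(kwargs_handlers = [ddp_kwargs, init_kwargs])
```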