Compare commits

..

5 Commits

6 changed files with 86 additions and 21 deletions

View File

@@ -1298,4 +1298,14 @@ For detailed information on training the diffusion prior, please refer to the [d
} }
``` ```
```bibtex
@article{Salimans2022ProgressiveDF,
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
author = {Tim Salimans and Jonathan Ho},
journal = {ArXiv},
year = {2022},
volume = {abs/2202.00512}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a> *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,6 +1,6 @@
from dalle2_pytorch.version import __version__ from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE from dalle2_pytorch.vqgan_vae import VQGanVAE

View File

@@ -100,6 +100,9 @@ def eval_decorator(fn):
return out return out
return inner return inner
def is_float_dtype(dtype):
    """Return True if `dtype` is torch.float64, float32, float16 or bfloat16."""
    # tuple membership performs the same equality checks without allocating a list
    return dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)
def is_list_str(x): def is_list_str(x):
if not isinstance(x, (list, tuple)): if not isinstance(x, (list, tuple)):
return False return False
@@ -386,6 +389,8 @@ class OpenClipAdapter(BaseClipAdapter):
self.eos_id = 49407 self.eos_id = 49407
text_attention_final = self.find_layer('ln_final') text_attention_final = self.find_layer('ln_final')
self._dim_latent = text_attention_final.weight.shape[0]
self.handle = text_attention_final.register_forward_hook(self._hook) self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1] self.clip_normalize = preprocess.transforms[-1]
self.cleared = False self.cleared = False
@@ -405,7 +410,7 @@ class OpenClipAdapter(BaseClipAdapter):
@property @property
def dim_latent(self): def dim_latent(self):
return 512 return self._dim_latent
@property @property
def image_size(self): def image_size(self):
@@ -614,7 +619,7 @@ class NoiseScheduler(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped return posterior_mean, posterior_variance, posterior_log_variance_clipped
def q_sample(self, x_start, t, noise=None): def q_sample(self, x_start, t, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
return ( return (
@@ -622,6 +627,12 @@ class NoiseScheduler(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
) )
def calculate_v(self, x_start, t, noise = None):
    """Compute the v-prediction training target for `x_start` at timestep `t`.

    v = sqrt_alphas_cumprod[t] * noise - sqrt_one_minus_alphas_cumprod[t] * x_start

    The minus sign is required for consistency with `predict_start_from_v`,
    which computes sqrt_alphas_cumprod[t] * x_t - sqrt_one_minus_alphas_cumprod[t] * v.
    Substituting x_t = sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise
    (assumed to be what `q_sample` produces - confirm against its first term),
    that expression only reduces to `x_start` when v is defined with the minus sign.

    `noise` should be the same noise used to produce the noised sample. It is
    multiplied directly, so the `None` default is not usable and callers must
    always pass it.
    """
    return (
        extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
        extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
    )
def q_sample_from_to(self, x_from, from_t, to_t, noise = None): def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
shape = x_from.shape shape = x_from.shape
noise = default(noise, lambda: torch.randn_like(x_from)) noise = default(noise, lambda: torch.randn_like(x_from))
@@ -633,6 +644,12 @@ class NoiseScheduler(nn.Module):
return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
def predict_start_from_v(self, x_t, t, v):
    """Estimate x_start from the sample `x_t` and a v value `v` at timestep `t`.

    Returns sqrt_alphas_cumprod[t] * x_t - sqrt_one_minus_alphas_cumprod[t] * v,
    with both schedule buffers looked up through `extract` using `x_t.shape`.
    """
    target_shape = x_t.shape
    signal_scale = extract(self.sqrt_alphas_cumprod, t, target_shape)
    noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, target_shape)
    return signal_scale * x_t - noise_scale * v
def predict_start_from_noise(self, x_t, t, noise): def predict_start_from_noise(self, x_t, t, noise):
return ( return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
@@ -968,6 +985,8 @@ class DiffusionPriorNetwork(nn.Module):
Rearrange('b (n d) -> b n d', n = num_text_embeds) Rearrange('b (n d) -> b n d', n = num_text_embeds)
) )
self.continuous_embedded_time = not exists(num_timesteps)
self.to_time_embeds = nn.Sequential( self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds) Rearrange('b (n d) -> b n d', n = num_time_embeds)
@@ -1095,6 +1114,9 @@ class DiffusionPriorNetwork(nn.Module):
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different) # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right # but let's just do it right
if self.continuous_embedded_time:
diffusion_timesteps = diffusion_timesteps.type(dtype)
time_embed = self.to_time_embeds(diffusion_timesteps) time_embed = self.to_time_embeds(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch) learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -1136,6 +1158,7 @@ class DiffusionPrior(nn.Module):
image_cond_drop_prob = None, image_cond_drop_prob = None,
loss_type = "l2", loss_type = "l2",
predict_x_start = True, predict_x_start = True,
predict_v = False,
beta_schedule = "cosine", beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs) sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
@@ -1187,6 +1210,7 @@ class DiffusionPrior(nn.Module):
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
self.predict_x_start = predict_x_start self.predict_x_start = predict_x_start
self.predict_v = predict_v # takes precedence over predict_x_start
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132 # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
@@ -1216,7 +1240,9 @@ class DiffusionPrior(nn.Module):
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond) pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
if self.predict_x_start: if self.predict_v:
x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
elif self.predict_x_start:
x_start = pred x_start = pred
else: else:
x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1289,7 +1315,9 @@ class DiffusionPrior(nn.Module):
# derive x0 # derive x0
if self.predict_x_start: if self.predict_v:
x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
elif self.predict_x_start:
x_start = pred x_start = pred
else: else:
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise) x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise)
@@ -1304,7 +1332,7 @@ class DiffusionPrior(nn.Module):
# predict noise # predict noise
if self.predict_x_start: if self.predict_x_start or self.predict_v:
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start) pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
else: else:
pred_noise = pred pred_noise = pred
@@ -1362,7 +1390,12 @@ class DiffusionPrior(nn.Module):
if self.predict_x_start and self.training_clamp_l2norm: if self.predict_x_start and self.training_clamp_l2norm:
pred = self.l2norm_clamp_embed(pred) pred = self.l2norm_clamp_embed(pred)
target = noise if not self.predict_x_start else image_embed if self.predict_v:
target = self.noise_scheduler.calculate_v(image_embed, times, noise)
elif self.predict_x_start:
target = image_embed
else:
target = noise
loss = self.noise_scheduler.loss_fn(pred, target) loss = self.noise_scheduler.loss_fn(pred, target)
return loss return loss
@@ -1432,7 +1465,7 @@ class DiffusionPrior(nn.Module):
**kwargs **kwargs
): ):
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied' assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied' assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization' assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
if exists(image): if exists(image):
@@ -1538,6 +1571,8 @@ class SinusoidalPosEmb(nn.Module):
def forward(self, x): def forward(self, x):
dtype, device = x.dtype, x.device dtype, device = x.dtype, x.device
assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
half_dim = self.dim // 2 half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1) emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb) emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
@@ -2436,6 +2471,7 @@ class Decoder(nn.Module):
loss_type = 'l2', loss_type = 'l2',
beta_schedule = None, beta_schedule = None,
predict_x_start = False, predict_x_start = False,
predict_v = False,
predict_x_start_for_latent_diffusion = False, predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops) random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
@@ -2608,6 +2644,10 @@ class Decoder(nn.Module):
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes)) self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
# predict v
self.predict_v = cast_tuple(predict_v, len(unets))
# input image range # input image range
self.input_image_range = (-1. if not auto_normalize_img else 0., 1.) self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
@@ -2719,14 +2759,16 @@ class Decoder(nn.Module):
x = x.clamp(-s, s) / s x = x.clamp(-s, s) / s
return x return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None): def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)' assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level)) model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output) pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
if predict_x_start: if predict_v:
x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
elif predict_x_start:
x_start = pred x_start = pred
else: else:
x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -2753,9 +2795,9 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance, x_start return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad() @torch.no_grad()
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None): def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level) model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
noise = torch.randn_like(x) noise = torch.randn_like(x)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
@@ -2770,6 +2812,7 @@ class Decoder(nn.Module):
image_embed, image_embed,
noise_scheduler, noise_scheduler,
predict_x_start = False, predict_x_start = False,
predict_v = False,
learned_variance = False, learned_variance = False,
clip_denoised = True, clip_denoised = True,
lowres_cond_img = None, lowres_cond_img = None,
@@ -2828,6 +2871,7 @@ class Decoder(nn.Module):
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level, lowres_noise_level = lowres_noise_level,
predict_x_start = predict_x_start, predict_x_start = predict_x_start,
predict_v = predict_v,
noise_scheduler = noise_scheduler, noise_scheduler = noise_scheduler,
learned_variance = learned_variance, learned_variance = learned_variance,
clip_denoised = clip_denoised clip_denoised = clip_denoised
@@ -2853,6 +2897,7 @@ class Decoder(nn.Module):
timesteps, timesteps,
eta = 1., eta = 1.,
predict_x_start = False, predict_x_start = False,
predict_v = False,
learned_variance = False, learned_variance = False,
clip_denoised = True, clip_denoised = True,
lowres_cond_img = None, lowres_cond_img = None,
@@ -2914,7 +2959,9 @@ class Decoder(nn.Module):
# predict x0 # predict x0
if predict_x_start: if predict_v:
x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
elif predict_x_start:
x_start = pred x_start = pred
else: else:
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred) x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
@@ -2926,8 +2973,8 @@ class Decoder(nn.Module):
# predict noise # predict noise
if predict_x_start: if predict_x_start or predict_v:
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred) pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
else: else:
pred_noise = pred pred_noise = pred
@@ -2963,7 +3010,7 @@ class Decoder(nn.Module):
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs) return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None): def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1] # normalize to [-1, 1]
@@ -3008,7 +3055,12 @@ class Decoder(nn.Module):
pred, _ = self.parse_unet_output(learned_variance, unet_output) pred, _ = self.parse_unet_output(learned_variance, unet_output)
target = noise if not predict_x_start else x_start if predict_v:
target = noise_scheduler.calculate_v(x_start, times, noise)
elif predict_x_start:
target = x_start
else:
target = noise
loss = noise_scheduler.loss_fn(pred, target, reduction = 'none') loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b (...)', 'mean') loss = reduce(loss, 'b ... -> b (...)', 'mean')
@@ -3094,7 +3146,7 @@ class Decoder(nn.Module):
num_unets = self.num_unets num_unets = self.num_unets
cond_scale = cast_tuple(cond_scale, num_unets) cond_scale = cast_tuple(cond_scale, num_unets)
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)): for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
if unet_number < start_at_unet_number: if unet_number < start_at_unet_number:
continue # It's the easiest way to do it continue # It's the easiest way to do it
@@ -3130,6 +3182,7 @@ class Decoder(nn.Module):
text_encodings = text_encodings, text_encodings = text_encodings,
cond_scale = unet_cond_scale, cond_scale = unet_cond_scale,
predict_x_start = predict_x_start, predict_x_start = predict_x_start,
predict_v = predict_v,
learned_variance = learned_variance, learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion, clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
@@ -3169,6 +3222,7 @@ class Decoder(nn.Module):
lowres_conditioner = self.lowres_conds[unet_index] lowres_conditioner = self.lowres_conds[unet_index]
target_image_size = self.image_sizes[unet_index] target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index] predict_x_start = self.predict_x_start[unet_index]
predict_v = self.predict_v[unet_index]
random_crop_size = self.random_crop_sizes[unet_index] random_crop_size = self.random_crop_sizes[unet_index]
learned_variance = self.learned_variance[unet_index] learned_variance = self.learned_variance[unet_index]
b, c, h, w, device, = *image.shape, image.device b, c, h, w, device, = *image.shape, image.device
@@ -3207,7 +3261,7 @@ class Decoder(nn.Module):
image = vae.encode(image) image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img) lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level) losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
if not return_lowres_cond_image: if not return_lowres_cond_image:
return losses return losses

View File

@@ -307,6 +307,7 @@ class DecoderTrainConfig(BaseModel):
wd: SingularOrIterable[float] = 0.01 wd: SingularOrIterable[float] = 0.01
warmup_steps: Optional[SingularOrIterable[int]] = None warmup_steps: Optional[SingularOrIterable[int]] = None
find_unused_parameters: bool = True find_unused_parameters: bool = True
static_graph: bool = True
max_grad_norm: SingularOrIterable[float] = 0.5 max_grad_norm: SingularOrIterable[float] = 0.5
save_every_n_samples: int = 100000 save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset

View File

@@ -1 +1 @@
__version__ = '1.10.7' __version__ = '1.11.0'

View File

@@ -556,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
torch.manual_seed(config.seed) torch.manual_seed(config.seed)
# Set up accelerator for configurable distributed training # Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters) ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60)) init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs]) accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])