Compare commits

...

14 Commits

Author SHA1 Message Date
Phil Wang
1d9ef99288 add PixelShuffleUpsample thanks to @MalumaDev and @marunine for running the experiment and verifyng absence of checkboard artifacts 2022-07-11 16:07:23 -07:00
Phil Wang
bdd62c24b3 zero init final projection in unet, since openai and @crowsonkb are both doing it 2022-07-11 13:22:06 -07:00
Phil Wang
1f1557c614 make it so even if text mask is omitted, it will be derived based on whether text encodings are all 0s or not, simplify dataloading 2022-07-11 10:56:19 -07:00
Aidan Dempster
1a217e99e3 Unet parameter count is now shown (#202) 2022-07-10 16:45:59 -07:00
Phil Wang
7ea314e2f0 allow for final l2norm clamping of the sampled image embed 2022-07-10 09:44:38 -07:00
Phil Wang
4173e88121 more accurate readme 2022-07-09 20:57:26 -07:00
Phil Wang
3dae43fa0e fix misnamed variable, thanks to @nousr 2022-07-09 19:01:37 -07:00
Phil Wang
a598820012 do not noise for the last step in ddim 2022-07-09 18:38:40 -07:00
Phil Wang
4878762627 fix for small validation bug for sampling steps 2022-07-09 17:31:54 -07:00
Phil Wang
47ae17b36e more informative error for something that tripped me up 2022-07-09 17:28:14 -07:00
Phil Wang
b7e22f7da0 complete ddim integration of diffusion prior as well as decoder for each unet, feature complete for https://github.com/lucidrains/DALLE2-pytorch/issues/157 2022-07-09 17:25:34 -07:00
Romain Beaumont
68de937aac Fix decoder test by fixing the resizing output size (#197) 2022-07-09 07:48:07 -07:00
Phil Wang
097afda606 0.18.0 2022-07-08 18:18:38 -07:00
Aidan Dempster
5c520db825 Added deepspeed support (#195) 2022-07-08 18:18:08 -07:00
8 changed files with 319 additions and 56 deletions

View File

@@ -355,7 +355,8 @@ prior_network = DiffusionPriorNetwork(
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
timesteps = 1000,
sample_timesteps = 64,
cond_drop_prob = 0.2
).cuda()
@@ -583,6 +584,7 @@ unet1 = Unet(
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8),
text_embed_dim = 512,
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()
@@ -598,7 +600,8 @@ decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
timesteps = 1000,
sample_timesteps = (250, 27),
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()

View File

@@ -41,7 +41,7 @@
"resample_train": true,
"preprocessing": {
"RandomResizedCrop": {
"size": [64, 64],
"size": [224, 224],
"scale": [0.75, 1.0],
"ratio": [1.0, 1.0]
},

View File

@@ -77,6 +77,11 @@ def cast_tuple(val, length = None):
def module_device(module):
return next(module.parameters()).device
def zero_init_(m):
nn.init.zeros_(m.weight)
if exists(m.bias):
nn.init.zeros_(m.bias)
@contextmanager
def null_context(*args, **kwargs):
yield
@@ -220,6 +225,7 @@ class XClipAdapter(BaseClipAdapter):
encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
@torch.no_grad()
@@ -255,6 +261,7 @@ class CoCaAdapter(BaseClipAdapter):
text = text[..., :self.max_text_len]
text_mask = text != 0
text_embed, text_encodings = self.clip.embed_text(text)
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(text_embed, text_encodings, text_mask)
@torch.no_grad()
@@ -314,6 +321,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
del self.text_encodings
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
@@ -505,6 +513,12 @@ class NoiseScheduler(nn.Module):
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def predict_noise_from_start(self, x_t, t, x0):
return (
(x0 - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t) / \
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
)
def p2_reweigh_loss(self, loss, times):
if not self.has_p2_loss_reweighting:
return loss
@@ -911,19 +925,23 @@ class DiffusionPrior(nn.Module):
image_size = None,
image_channels = 3,
timesteps = 1000,
sample_timesteps = None,
cond_drop_prob = 0.,
loss_type = "l2",
predict_x_start = True,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False,
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
training_clamp_l2norm = False,
init_image_embed_l2norm = False,
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
super().__init__()
self.sample_timesteps = sample_timesteps
self.noise_scheduler = NoiseScheduler(
beta_schedule = beta_schedule,
timesteps = timesteps,
@@ -954,23 +972,32 @@ class DiffusionPrior(nn.Module):
self.condition_on_text_encodings = condition_on_text_encodings
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
self.predict_x_start = predict_x_start
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
# whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm
self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
# device tracker
self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
@property
def device(self):
return self._dummy.device
def l2norm_clamp_embed(self, image_embed):
return l2norm(image_embed) * self.image_embed_scale
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -978,8 +1005,6 @@ class DiffusionPrior(nn.Module):
if self.predict_x_start:
x_recon = pred
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
else:
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1002,21 +1027,81 @@ class DiffusionPrior(nn.Module):
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
device = self.device
b = shape[0]
image_embed = torch.randn(shape, device=device)
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
batch, device = shape[0], self.device
image_embed = torch.randn(shape, device = device)
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long)
times = torch.full((batch,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
if self.sampling_final_clamp_l2norm and self.predict_x_start:
image_embed = self.l2norm_clamp_embed(image_embed)
return image_embed
@torch.no_grad()
def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
image_embed = torch.randn(shape, device = device)
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = alphas[time]
alpha_next = alphas[time_next]
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
if self.predict_x_start:
x_start = pred
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
else:
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
pred_noise = pred
if not self.predict_x_start:
x_start.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm:
x_start = self.l2norm_clamp_embed(x_start)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
image_embed = x_start * alpha_next.sqrt() + \
c1 * noise + \
c2 * pred_noise
if self.predict_x_start and self.sampling_final_clamp_l2norm:
image_embed = self.l2norm_clamp_embed(image_embed)
return image_embed
@torch.no_grad()
def p_sample_loop(self, *args, timesteps = None, **kwargs):
timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
assert timesteps <= self.noise_scheduler.num_timesteps
is_ddim = timesteps < self.noise_scheduler.num_timesteps
if not is_ddim:
return self.p_sample_loop_ddpm(*args, **kwargs)
return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
@@ -1030,7 +1115,7 @@ class DiffusionPrior(nn.Module):
)
if self.predict_x_start and self.training_clamp_l2norm:
pred = l2norm(pred) * self.image_embed_scale
pred = self.l2norm_clamp_embed(pred)
target = noise if not self.predict_x_start else image_embed
@@ -1051,7 +1136,15 @@ class DiffusionPrior(nn.Module):
@torch.no_grad()
@eval_decorator
def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
def sample(
self,
text,
num_samples_per_batch = 2,
cond_scale = 1.,
timesteps = None
):
timesteps = default(timesteps, self.sample_timesteps)
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
@@ -1066,7 +1159,7 @@ class DiffusionPrior(nn.Module):
if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
# retrieve original unscaled image embed
@@ -1112,6 +1205,7 @@ class DiffusionPrior(nn.Module):
if self.condition_on_text_encodings:
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
# timestep conditioning from ddpm
@@ -1129,16 +1223,35 @@ class DiffusionPrior(nn.Module):
# decoder
def ConvTransposeUpsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
class PixelShuffleUpsample(nn.Module):
"""
code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
"""
def __init__(self, dim, dim_out = None):
super().__init__()
dim_out = default(dim_out, dim)
conv = nn.Conv2d(dim, dim_out * 4, 1)
def NearestUpsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
nn.Conv2d(dim, dim_out, 3, padding = 1)
)
self.net = nn.Sequential(
conv,
nn.SiLU(),
nn.PixelShuffle(2)
)
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, h, w = conv.weight.shape
conv_weight = torch.empty(o // 4, i, h, w)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
return self.net(x)
def Downsample(dim, *, dim_out = None):
dim_out = default(dim_out, dim)
@@ -1402,7 +1515,7 @@ class Unet(nn.Module):
cross_embed_downsample_kernel_sizes = (2, 4),
memory_efficient = False,
scale_skip_connection = False,
nearest_upsample = False,
pixel_shuffle_upsample = True,
final_conv_kernel_size = 1,
**kwargs
):
@@ -1468,10 +1581,12 @@ class Unet(nn.Module):
# text encoding conditioning (optional)
self.text_to_cond = None
self.text_embed_dim = None
if cond_on_text_encodings:
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
self.text_embed_dim = text_embed_dim
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1514,7 +1629,7 @@ class Unet(nn.Module):
# upsample klass
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
# give memory efficient unet an initial resnet block
@@ -1578,6 +1693,8 @@ class Unet(nn.Module):
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters(
@@ -1700,6 +1817,8 @@ class Unet(nn.Module):
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
text_tokens = self.text_to_cond(text_encodings)
text_tokens = text_tokens[:, :self.max_text_len]
@@ -1853,6 +1972,7 @@ class Decoder(nn.Module):
channels = 3,
vae = tuple(),
timesteps = 1000,
sample_timesteps = None,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l2',
@@ -1876,7 +1996,8 @@ class Decoder(nn.Module):
use_dynamic_thres = False, # from the Imagen paper
dynamic_thres_percentile = 0.9,
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1
p2_loss_weight_k = 1,
ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
):
super().__init__()
@@ -1956,6 +2077,11 @@ class Decoder(nn.Module):
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# sampling timesteps, defaults to non-ddim with full timesteps sampling
self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)
self.ddim_sampling_eta = ddim_sampling_eta
# create noise schedulers per unet
if not exists(beta_schedule):
@@ -1966,7 +2092,9 @@ class Decoder(nn.Module):
self.noise_schedulers = nn.ModuleList([])
for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
noise_scheduler = NoiseScheduler(
beta_schedule = unet_beta_schedule,
timesteps = timesteps,
@@ -2067,6 +2195,26 @@ class Decoder(nn.Module):
for unet, device in zip(self.unets, devices):
unet.to(device)
def dynamic_threshold(self, x):
""" proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance """
# s is the threshold amount
# static thresholding would just be s = 1
s = 1.
if self.use_dynamic_thres:
s = torch.quantile(
rearrange(x, 'b ... -> b (...)').abs(),
self.dynamic_thres_percentile,
dim = -1
)
s.clamp_(min = 1.)
s = s.view(-1, *((1,) * (x.ndim - 1)))
# clip by threshold, depending on whether static or dynamic
x = x.clamp(-s, s) / s
return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -2081,21 +2229,7 @@ class Decoder(nn.Module):
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised:
# s is the threshold amount
# static thresholding would just be s = 1
s = 1.
if self.use_dynamic_thres:
s = torch.quantile(
rearrange(x_recon, 'b ... -> b (...)').abs(),
self.dynamic_thres_percentile,
dim = -1
)
s.clamp_(min = 1.)
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
# clip by threshold, depending on whether static or dynamic
x_recon = x_recon.clamp(-s, s) / s
x_recon = self.dynamic_threshold(x_recon)
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
@@ -2125,7 +2259,7 @@ class Decoder(nn.Module):
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.device
b = shape[0]
@@ -2153,6 +2287,62 @@ class Decoder(nn.Module):
unnormalize_img = self.unnormalize_img(img)
return unnormalize_img
@torch.no_grad()
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
img = torch.randn(shape, device = device)
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = alphas[time]
alpha_next = alphas[time_next]
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
if predict_x_start:
x_start = pred
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
pred_noise = pred
if clip_denoised:
x_start = self.dynamic_threshold(x_start)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if time_next > 0 else 0.
img = x_start * alpha_next.sqrt() + \
c1 * noise + \
c2 * pred_noise
img = self.unnormalize_img(img)
return img
@torch.no_grad()
def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
num_timesteps = noise_scheduler.num_timesteps
timesteps = default(timesteps, num_timesteps)
assert timesteps <= num_timesteps
is_ddim = timesteps < num_timesteps
if not is_ddim:
return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -2250,10 +2440,13 @@ class Decoder(nn.Module):
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
if self.condition_on_text_encodings:
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
img = None
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
@@ -2282,7 +2475,8 @@ class Decoder(nn.Module):
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img,
is_latent_diffusion = is_latent_diffusion,
noise_scheduler = noise_scheduler
noise_scheduler = noise_scheduler,
timesteps = sample_timesteps
)
img = vae.decode(img)
@@ -2332,6 +2526,9 @@ class Decoder(nn.Module):
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
if self.condition_on_text_encodings:
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)

View File

@@ -1,6 +1,7 @@
import os
import webdataset as wds
import torch
from torch.utils.data import DataLoader
import numpy as np
import fsspec
import shutil
@@ -255,7 +256,7 @@ def create_image_embedding_dataloader(
)
if shuffle_num is not None and shuffle_num > 0:
ds.shuffle(1000)
return wds.WebLoader(
return DataLoader(
ds,
num_workers=num_workers,
batch_size=batch_size,

View File

@@ -154,6 +154,7 @@ class DiffusionPriorConfig(BaseModel):
image_size: int
image_channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[int] = None
cond_drop_prob: float = 0.
loss_type: str = 'l2'
predict_x_start: bool = True
@@ -233,6 +234,7 @@ class DecoderConfig(BaseModel):
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable(int)] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True

View File

@@ -21,7 +21,7 @@ import pytorch_warmup as warmup
from ema_pytorch import EMA
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
import numpy as np
@@ -76,6 +76,7 @@ def cast_torch_tensor(fn):
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', next(model.parameters()).device)
cast_device = kwargs.pop('_cast_device', True)
cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
@@ -85,6 +86,21 @@ def cast_torch_tensor(fn):
if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
if cast_deepspeed_precision:
try:
accelerator = model.accelerator
if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
except AttributeError:
# Then this model doesn't have an accelerator
pass
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
@@ -446,6 +462,7 @@ class DecoderTrainer(nn.Module):
self,
decoder,
accelerator = None,
dataloaders = None,
use_ema = True,
lr = 1e-4,
wd = 1e-2,
@@ -508,10 +525,31 @@ class DecoderTrainer(nn.Module):
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
# Then we need to make sure clip is using the correct precision or else deepspeed will error
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
clip = decoder.clip
clip.to(precision_type)
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
self.decoder = decoder
# prepare dataloaders
train_loader = val_loader = None
if exists(dataloaders):
train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
self.train_loader = train_loader
self.val_loader = val_loader
# store optimizers
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
@@ -675,6 +713,9 @@ class DecoderTrainer(nn.Module):
total_loss = 0.
using_amp = self.accelerator.mixed_precision != 'no'
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with self.accelerator.autocast():
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)

View File

@@ -1 +1 @@
__version__ = '0.17.1'
__version__ = '0.21.0'

View File

@@ -274,6 +274,7 @@ def train(
trainer = DecoderTrainer(
decoder=decoder,
accelerator=accelerator,
dataloaders=dataloaders,
**kwargs
)
@@ -284,7 +285,6 @@ def train(
sample = 0
samples_seen = 0
val_sample = 0
step = lambda: int(trainer.num_steps_taken(unet_number=1))
if tracker.can_recall:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
@@ -299,6 +299,8 @@ def train(
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
@@ -321,7 +323,7 @@ def train(
last_snapshot = sample
if next_task == 'train':
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
for i, (img, emb, txt) in enumerate(trainer.train_loader):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -414,7 +416,7 @@ def train(
timer = Timer()
accelerator.wait_for_everyone()
i = 0
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
@@ -519,6 +521,20 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
if accelerator.num_processes > 1:
# We are using distributed training and want to immediately ensure all can connect
accelerator.print("Waiting for all processes to connect...")
accelerator.wait_for_everyone()
accelerator.print("All processes online and connected")
# If we are in deepspeed fp16 mode, we must ensure learned variance is off
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
raise ValueError("DeepSpeed fp16 mode does not support learned variance")
if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
# This is an invalid configuration until we figure out how to handle this
raise ValueError("DeepSpeed does not support multi-node distributed training")
# Set up data
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
@@ -541,7 +557,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Create the decoder model and print basic info
decoder = config.decoder.create()
num_parameters = sum(p.numel() for p in decoder.parameters())
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
@@ -570,7 +586,10 @@ def initialize_training(config: TrainDecoderConfig, config_path):
accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
accelerator.print(f"Number of parameters: {num_parameters}")
accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
for i, unet in enumerate(decoder.unets):
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=accelerator.device,