Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-13 21:34:21 +01:00
Compare commits
8 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 3dae43fa0e | |
| | a598820012 | |
| | 4878762627 | |
| | 47ae17b36e | |
| | b7e22f7da0 | |
| | 68de937aac | |
| | 097afda606 | |
| | 5c520db825 | |
@@ -583,6 +583,7 @@ unet1 = Unet(
     cond_dim = 128,
     channels = 3,
     dim_mults=(1, 2, 4, 8),
+    text_embed_dim = 512,
     cond_on_text_encodings = True    # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
 ).cuda()

@@ -598,7 +599,8 @@ decoder = Decoder(
     unet = (unet1, unet2),
     image_sizes = (128, 256),
     clip = clip,
-    timesteps = 100,
+    timesteps = 1000,
+    sample_timesteps = (250, 27),
     image_cond_drop_prob = 0.1,
     text_cond_drop_prob = 0.5
 ).cuda()

@@ -41,7 +41,7 @@
     "resample_train": true,
     "preprocessing": {
         "RandomResizedCrop": {
-            "size": [64, 64],
+            "size": [224, 224],
             "scale": [0.75, 1.0],
             "ratio": [1.0, 1.0]
         },

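The `RandomResizedCrop` block above maps directly onto torchvision's transform of the same name. A minimal sketch of the preprocessing this config describes, assuming torchvision is installed and using a dummy tensor image in place of a real dataset sample:

```python
import torch
from torchvision import transforms

# random crops covering 75-100% of the image area, fixed 1:1 aspect ratio,
# resized to 224 x 224 (the values from the config above)
preprocess = transforms.RandomResizedCrop(
    size = (224, 224),
    scale = (0.75, 1.0),
    ratio = (1.0, 1.0)
)

dummy_image = torch.rand(3, 256, 256)   # C x H x W tensor stands in for a decoded image
cropped = preprocess(dummy_image)
print(cropped.shape)                    # torch.Size([3, 224, 224])
```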
@@ -505,6 +505,12 @@ class NoiseScheduler(nn.Module):
             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
         )

+    def predict_noise_from_start(self, x_t, t, x0):
+        return (
+            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
+            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+        )
+
     def p2_reweigh_loss(self, loss, times):
         if not self.has_p2_loss_reweighting:
             return loss

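`predict_noise_from_start` is the inverse of `predict_start_from_noise`: with x_t = √ᾱ_t·x₀ + √(1−ᾱ_t)·ε, solving for ε gives ε = (x_t/√ᾱ_t − x₀) / √(1/ᾱ_t − 1). A minimal standalone round-trip check of that identity, using plain tensors rather than the repo's `NoiseScheduler` and `extract` helpers, and a toy alphas schedule made up for the check:

```python
import torch

torch.manual_seed(0)

# toy schedule: linearly spaced cumulative alphas, only for this check
alphas_cumprod = torch.linspace(0.99, 0.01, 1000)

t = 500
x0 = torch.randn(4, 512)
noise = torch.randn_like(x0)

ac = alphas_cumprod[t]
sqrt_recip = (1. / ac).sqrt()          # 1 / sqrt(alpha_bar_t)
sqrt_recipm1 = (1. / ac - 1.).sqrt()   # sqrt(1 / alpha_bar_t - 1)

# forward diffusion: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise
x_t = ac.sqrt() * x0 + (1. - ac).sqrt() * noise

# predict_start_from_noise: recover x0 from x_t and the true noise
x0_hat = sqrt_recip * x_t - sqrt_recipm1 * noise

# predict_noise_from_start: recover the noise from x_t and the true x0
noise_hat = (sqrt_recip * x_t - x0) / sqrt_recipm1

assert torch.allclose(x0_hat, x0, atol = 1e-4)
assert torch.allclose(noise_hat, noise, atol = 1e-4)
```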
@@ -911,6 +917,7 @@ class DiffusionPrior(nn.Module):
         image_size = None,
         image_channels = 3,
         timesteps = 1000,
+        sample_timesteps = None,
         cond_drop_prob = 0.,
         loss_type = "l2",
         predict_x_start = True,
@@ -924,6 +931,8 @@ class DiffusionPrior(nn.Module):
     ):
         super().__init__()

+        self.sample_timesteps = sample_timesteps
+
         self.noise_scheduler = NoiseScheduler(
             beta_schedule = beta_schedule,
             timesteps = timesteps,
@@ -978,8 +987,6 @@ class DiffusionPrior(nn.Module):

         if self.predict_x_start:
             x_recon = pred
-            # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
-            # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
         else:
             x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

@@ -1002,21 +1009,75 @@ class DiffusionPrior(nn.Module):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
-        device = self.device
-
-        b = shape[0]
-        image_embed = torch.randn(shape, device=device)
+    def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
+        batch, device = shape[0], self.device
+        image_embed = torch.randn(shape, device = device)

         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale

         for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
-            times = torch.full((b,), i, device = device, dtype = torch.long)
+            times = torch.full((batch,), i, device = device, dtype = torch.long)
             image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)

         return image_embed

+    @torch.no_grad()
+    def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
+        batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
+
+        times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
+
+        times = list(reversed(times.int().tolist()))
+        time_pairs = list(zip(times[:-1], times[1:]))
+
+        image_embed = torch.randn(shape, device = device)
+
+        if self.init_image_embed_l2norm:
+            image_embed = l2norm(image_embed) * self.image_embed_scale
+
+        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+            alpha = alphas[time]
+            alpha_next = alphas[time_next]
+
+            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+
+            if self.predict_x_start:
+                x_start = pred
+                pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
+            else:
+                x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
+                pred_noise = pred
+
+            if not self.predict_x_start:
+                x_start.clamp_(-1., 1.)
+
+            if self.predict_x_start and self.sampling_clamp_l2norm:
+                x_start = l2norm(x_start) * self.image_embed_scale
+
+            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+            noise = torch.randn_like(image_embed) if time_next > 0 else 0.
+
+            image_embed = x_start * alpha_next.sqrt() + \
+                          c1 * noise + \
+                          c2 * pred_noise
+
+        return image_embed
+
+    @torch.no_grad()
+    def p_sample_loop(self, *args, timesteps = None, **kwargs):
+        timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
+        assert timesteps <= self.noise_scheduler.num_timesteps
+        is_ddim = timesteps < self.noise_scheduler.num_timesteps
+
+        if not is_ddim:
+            return self.p_sample_loop_ddpm(*args, **kwargs)
+
+        return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+
     def p_losses(self, image_embed, times, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))

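A compact sketch of the DDIM update used inside the loop above, written against plain tensors rather than the class: given the predicted x₀ and predicted noise at the current step, the next iterate is √ᾱ_next · x₀ + c1 · z + c2 · ε̂, where z is fresh Gaussian noise and c1 collapses to 0 at eta = 0 (fully deterministic sampling). The alpha values below are made up for illustration:

```python
import torch

def ddim_step(x_t, x_start, pred_noise, alpha, alpha_next, eta = 0., add_noise = True):
    # sigma-style coefficient from the DDIM paper; eta = 0 gives a deterministic
    # update, eta = 1 reintroduces DDPM-like stochasticity
    c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
    c2 = ((1 - alpha_next) - c1 ** 2).sqrt()
    noise = torch.randn_like(x_t) if add_noise else torch.zeros_like(x_t)
    return x_start * alpha_next.sqrt() + c1 * noise + c2 * pred_noise

# toy usage with made-up cumulative alphas for two neighbouring sampled timesteps
alpha, alpha_next = torch.tensor(0.30), torch.tensor(0.55)
x_t = torch.randn(2, 512)
x_start = torch.randn_like(x_t)
pred_noise = torch.randn_like(x_t)

x_next = ddim_step(x_t, x_start, pred_noise, alpha, alpha_next, eta = 0., add_noise = False)
print(x_next.shape)   # torch.Size([2, 512])
```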
@@ -1051,7 +1112,15 @@ class DiffusionPrior(nn.Module):

     @torch.no_grad()
     @eval_decorator
-    def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
+    def sample(
+        self,
+        text,
+        num_samples_per_batch = 2,
+        cond_scale = 1.,
+        timesteps = None
+    ):
+        timesteps = default(timesteps, self.sample_timesteps)
+
         # in the paper, what they did was
         # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
         text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
@@ -1066,7 +1135,7 @@ class DiffusionPrior(nn.Module):
         if self.condition_on_text_encodings:
             text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}

-        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
+        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)

         # retrieve original unscaled image embed

@@ -1468,10 +1537,12 @@ class Unet(nn.Module):
         # text encoding conditioning (optional)

         self.text_to_cond = None
+        self.text_embed_dim = None

         if cond_on_text_encodings:
             assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
             self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
+            self.text_embed_dim = text_embed_dim

         # finer control over whether to condition on image embeddings and text encodings
         # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1700,6 +1771,8 @@ class Unet(nn.Module):
         text_tokens = None

         if exists(text_encodings) and self.cond_on_text_encodings:
+            assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
+
             text_tokens = self.text_to_cond(text_encodings)
             text_tokens = text_tokens[:, :self.max_text_len]

@@ -1853,6 +1926,7 @@ class Decoder(nn.Module):
         channels = 3,
         vae = tuple(),
         timesteps = 1000,
+        sample_timesteps = None,
         image_cond_drop_prob = 0.1,
         text_cond_drop_prob = 0.5,
         loss_type = 'l2',
@@ -1876,7 +1950,8 @@ class Decoder(nn.Module):
         use_dynamic_thres = False,  # from the Imagen paper
         dynamic_thres_percentile = 0.9,
         p2_loss_weight_gamma = 0.,  # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
-        p2_loss_weight_k = 1
+        p2_loss_weight_k = 1,
+        ddim_sampling_eta = 1.  # can be set to 0. for deterministic sampling afaict
     ):
         super().__init__()

@@ -1956,6 +2031,11 @@ class Decoder(nn.Module):
             self.unets.append(one_unet)
             self.vaes.append(one_vae.copy_for_eval())

+        # sampling timesteps, defaults to non-ddim with full timesteps sampling
+
+        self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)
+        self.ddim_sampling_eta = ddim_sampling_eta
+
         # create noise schedulers per unet

         if not exists(beta_schedule):
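`cast_tuple` is what lets `sample_timesteps` be given either as a single value or as one value per unet. The repo has its own helper; the sketch below is a hypothetical re-implementation, just to show the behaviour assumed above:

```python
def cast_tuple(val, length = 1):
    # broadcast a single value to a tuple of the requested length,
    # leave tuples/lists untouched
    return tuple(val) if isinstance(val, (tuple, list)) else ((val,) * length)

print(cast_tuple(None, 2))        # (None, None) -> full DDPM sampling for both unets
print(cast_tuple(250, 2))         # (250, 250)
print(cast_tuple((250, 27), 2))   # (250, 27)    -> per-unet DDIM step counts
```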
@@ -1966,7 +2046,9 @@ class Decoder(nn.Module):

         self.noise_schedulers = nn.ModuleList([])

-        for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
+        for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
+            assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
+
             noise_scheduler = NoiseScheduler(
                 beta_schedule = unet_beta_schedule,
                 timesteps = timesteps,
@@ -2067,6 +2149,26 @@ class Decoder(nn.Module):
         for unet, device in zip(self.unets, devices):
             unet.to(device)

+    def dynamic_threshold(self, x):
+        """ proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance """
+
+        # s is the threshold amount
+        # static thresholding would just be s = 1
+        s = 1.
+        if self.use_dynamic_thres:
+            s = torch.quantile(
+                rearrange(x, 'b ... -> b (...)').abs(),
+                self.dynamic_thres_percentile,
+                dim = -1
+            )
+
+            s.clamp_(min = 1.)
+            s = s.view(-1, *((1,) * (x.ndim - 1)))
+
+        # clip by threshold, depending on whether static or dynamic
+        x = x.clamp(-s, s) / s
+        return x
+
     def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

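A quick standalone illustration of the dynamic thresholding above, the same quantile-clamp-rescale recipe written outside the class, applied to a batch whose values occasionally fall well outside [-1, 1]:

```python
import torch
from einops import rearrange

def dynamic_threshold(x, percentile = 0.9):
    # per-sample threshold at the given quantile of |x|, never allowed below 1
    s = torch.quantile(rearrange(x, 'b ... -> b (...)').abs(), percentile, dim = -1)
    s.clamp_(min = 1.)
    s = s.view(-1, *((1,) * (x.ndim - 1)))
    # clamp to [-s, s] then rescale back into [-1, 1]
    return x.clamp(-s, s) / s

x = torch.randn(4, 3, 8, 8) * 2.   # some values well outside [-1, 1]
out = dynamic_threshold(x)
print(out.abs().max())             # <= 1.0
```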
@@ -2081,21 +2183,7 @@ class Decoder(nn.Module):
             x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

         if clip_denoised:
-            # s is the threshold amount
-            # static thresholding would just be s = 1
-            s = 1.
-            if self.use_dynamic_thres:
-                s = torch.quantile(
-                    rearrange(x_recon, 'b ... -> b (...)').abs(),
-                    self.dynamic_thres_percentile,
-                    dim = -1
-                )
-
-                s.clamp_(min = 1.)
-                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
-
-            # clip by threshold, depending on whether static or dynamic
-            x_recon = x_recon.clamp(-s, s) / s
+            x_recon = self.dynamic_threshold(x_recon)

         model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)

@@ -2125,7 +2213,7 @@ class Decoder(nn.Module):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+    def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
         device = self.device

         b = shape[0]
@@ -2153,6 +2241,62 @@ class Decoder(nn.Module):
         unnormalize_img = self.unnormalize_img(img)
         return unnormalize_img

+    @torch.no_grad()
+    def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+        batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
+
+        times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
+
+        times = list(reversed(times.int().tolist()))
+        time_pairs = list(zip(times[:-1], times[1:]))
+
+        img = torch.randn(shape, device = device)
+
+        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+            alpha = alphas[time]
+            alpha_next = alphas[time_next]
+
+            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+
+            if learned_variance:
+                pred, _ = pred.chunk(2, dim = 1)
+
+            if predict_x_start:
+                x_start = pred
+                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
+            else:
+                x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
+                pred_noise = pred
+
+            if clip_denoised:
+                x_start = self.dynamic_threshold(x_start)
+
+            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+            noise = torch.randn_like(img) if time_next > 0 else 0.
+
+            img = x_start * alpha_next.sqrt() + \
+                  c1 * noise + \
+                  c2 * pred_noise
+
+        img = self.unnormalize_img(img)
+        return img
+
+    @torch.no_grad()
+    def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
+        num_timesteps = noise_scheduler.num_timesteps
+
+        timesteps = default(timesteps, num_timesteps)
+        assert timesteps <= num_timesteps
+        is_ddim = timesteps < num_timesteps
+
+        if not is_ddim:
+            return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)
+
+        return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
+
     def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
         noise = default(noise, lambda: torch.randn_like(x_start))

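The `torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]` line above is what picks the subset of training timesteps the DDIM loop actually visits. A small sketch of the resulting schedule, assuming 1000 training steps and 10 sampling steps:

```python
import torch

total_timesteps, timesteps = 1000, 10

times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))

print(times)
# [909, 818, 727, 636, 545, 454, 363, 272, 181, 90, 0]
print(time_pairs[:3])
# [(909, 818), (818, 727), (727, 636)]
```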
@@ -2253,7 +2397,7 @@ class Decoder(nn.Module):
         img = None
         is_cuda = next(self.parameters()).is_cuda

-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):

             context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()

@@ -2282,7 +2426,8 @@ class Decoder(nn.Module):
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,
                     is_latent_diffusion = is_latent_diffusion,
-                    noise_scheduler = noise_scheduler
+                    noise_scheduler = noise_scheduler,
+                    timesteps = sample_timesteps
                 )

                 img = vae.decode(img)

@@ -1,6 +1,7 @@
 import os
 import webdataset as wds
 import torch
+from torch.utils.data import DataLoader
 import numpy as np
 import fsspec
 import shutil
@@ -255,7 +256,7 @@ def create_image_embedding_dataloader(
     )
     if shuffle_num is not None and shuffle_num > 0:
         ds.shuffle(1000)
-    return wds.WebLoader(
+    return DataLoader(
         ds,
         num_workers=num_workers,
         batch_size=batch_size,

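With the loader now returned as a plain `torch.utils.data.DataLoader`, the surrounding pipeline presumably looks roughly like the sketch below. The shard url pattern, decode format, and tuple keys are placeholders, not the repo's actual dataset layout:

```python
import webdataset as wds
from torch.utils.data import DataLoader

# hypothetical shard pattern and keys; replace with the real dataset layout
urls = "pipe:cat /data/shard-{00000..00099}.tar"

ds = (
    wds.WebDataset(urls)
    .decode("torchrgb")              # decode images to float CHW tensors
    .to_tuple("jpg", "npy", "txt")   # (image, embedding, caption) per sample
)

loader = DataLoader(
    ds,
    num_workers = 4,
    batch_size = 64,
    prefetch_factor = 2,
    pin_memory = True
)
```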
@@ -154,6 +154,7 @@ class DiffusionPriorConfig(BaseModel):
     image_size: int
     image_channels: int = 3
     timesteps: int = 1000
+    sample_timesteps: Optional[int] = None
     cond_drop_prob: float = 0.
     loss_type: str = 'l2'
     predict_x_start: bool = True
@@ -233,6 +234,7 @@ class DecoderConfig(BaseModel):
     clip: Optional[AdapterConfig]   # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
+    sample_timesteps: Optional[SingularOrIterable(int)] = None
     loss_type: str = 'l2'
     beta_schedule: ListOrTuple(str) = 'cosine'
     learned_variance: bool = True

@@ -21,7 +21,7 @@ import pytorch_warmup as warmup

 from ema_pytorch import EMA

-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType

 import numpy as np

@@ -76,6 +76,7 @@ def cast_torch_tensor(fn):
     def inner(model, *args, **kwargs):
         device = kwargs.pop('_device', next(model.parameters()).device)
         cast_device = kwargs.pop('_cast_device', True)
+        cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)

         kwargs_keys = kwargs.keys()
         all_args = (*args, *kwargs.values())
@@ -85,6 +86,21 @@ def cast_torch_tensor(fn):
         if cast_device:
             all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

+        if cast_deepspeed_precision:
+            try:
+                accelerator = model.accelerator
+                if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
+                    cast_type_map = {
+                        "fp16": torch.half,
+                        "bf16": torch.bfloat16,
+                        "no": torch.float
+                    }
+                    precision_type = cast_type_map[accelerator.mixed_precision]
+                    all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+            except AttributeError:
+                # Then this model doesn't have an accelerator
+                pass
+
         args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
         kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

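The DeepSpeed branch above just maps Accelerate's `mixed_precision` string to a torch dtype and casts any floating tensor arguments. A stripped-down sketch of that mapping outside the decorator, with the `mixed_precision` value hard-coded here purely for illustration:

```python
import torch

cast_type_map = {
    "fp16": torch.half,
    "bf16": torch.bfloat16,
    "no": torch.float
}

def cast_args_to_precision(args, mixed_precision):
    # cast tensor arguments to the dtype deepspeed expects, leave everything else alone
    dtype = cast_type_map[mixed_precision]
    return tuple(a.to(dtype) if isinstance(a, torch.Tensor) else a for a in args)

args = (torch.randn(2, 3), "a caption", torch.ones(2))
print([a.dtype if isinstance(a, torch.Tensor) else a for a in cast_args_to_precision(args, "fp16")])
# [torch.float16, 'a caption', torch.float16]
```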
@@ -446,6 +462,7 @@ class DecoderTrainer(nn.Module):
         self,
         decoder,
         accelerator = None,
+        dataloaders = None,
         use_ema = True,
         lr = 1e-4,
         wd = 1e-2,
@@ -508,10 +525,31 @@ class DecoderTrainer(nn.Module):

         self.register_buffer('steps', torch.tensor([0] * self.num_unets))

+        if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
+            # Then we need to make sure clip is using the correct precision or else deepspeed will error
+            cast_type_map = {
+                "fp16": torch.half,
+                "bf16": torch.bfloat16,
+                "no": torch.float
+            }
+            precision_type = cast_type_map[accelerator.mixed_precision]
+            assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
+            clip = decoder.clip
+            clip.to(precision_type)
+
         decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))

         self.decoder = decoder

+        # prepare dataloaders
+
+        train_loader = val_loader = None
+        if exists(dataloaders):
+            train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
+
+        self.train_loader = train_loader
+        self.val_loader = val_loader
+
         # store optimizers

         for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
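Preparing the dataloaders through the same `Accelerator` that wraps the decoder and optimizers is the standard Accelerate pattern: `prepare()` handles device placement and shards the loaders across processes. A minimal sketch with a toy model and data, not the actual `DecoderTrainer` wiring:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)
train_loader = DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size = 4)
val_loader = DataLoader(TensorDataset(torch.randn(8, 8), torch.randn(8, 1)), batch_size = 4)

# prepare() wraps everything for the active distributed setup (DDP, DeepSpeed, ...)
# and distributes the dataloaders across processes
model, optimizer, train_loader, val_loader = accelerator.prepare(model, optimizer, train_loader, val_loader)

for x, y in train_loader:
    loss = nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```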
@@ -675,6 +713,9 @@ class DecoderTrainer(nn.Module):

         total_loss = 0.

+
+        using_amp = self.accelerator.mixed_precision != 'no'
+
         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with self.accelerator.autocast():
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
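The loop above splits an oversized batch into chunks and weights each chunk's loss by the fraction of the batch it covers, inside the accelerator's autocast context. A self-contained sketch of that pattern with a toy model (`split_args_and_kwargs` is the repo's helper; here the split is done by hand on a single tensor):

```python
import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator()
model = nn.Linear(16, 1)

def chunked_loss(model, batch, max_batch_size):
    total_loss = 0.
    for chunk in batch.split(max_batch_size):
        chunk_size_frac = chunk.shape[0] / batch.shape[0]
        with accelerator.autocast():
            loss = model(chunk).pow(2).mean()
        # weight each chunk by its share of the full batch so the weighted sum
        # equals the mean loss of a single full-batch forward
        total_loss = total_loss + loss.item() * chunk_size_frac
    return total_loss

batch = torch.randn(10, 16)
print(chunked_loss(model, batch, max_batch_size = 4))
```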
@@ -1 +1 @@
-__version__ = '0.17.1'
+__version__ = '0.19.5'

@@ -274,6 +274,7 @@ def train(
     trainer = DecoderTrainer(
         decoder=decoder,
         accelerator=accelerator,
+        dataloaders=dataloaders,
         **kwargs
     )

@@ -284,7 +285,6 @@ def train(
     sample = 0
     samples_seen = 0
     val_sample = 0
-    step = lambda: int(trainer.num_steps_taken(unet_number=1))

     if tracker.can_recall:
         start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
@@ -299,6 +299,8 @@ def train(
     if not exists(unet_training_mask):
         # Then the unet mask should be true for all unets in the decoder
         unet_training_mask = [True] * trainer.num_unets
+    first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
+    step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
     assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"

     accelerator.print(print_ribbon("Generating Example Data", repeat=40))
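The step counter now follows the first unet that is actually being trained rather than always unet 1. A tiny worked example of how the mask resolves to a unet number (the mask values are made up):

```python
unet_training_mask = [False, True]   # hypothetical: freeze unet 1, train only unet 2

# index of the first unet whose mask entry is True
first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)

print(first_training_unet)       # 1
print(first_training_unet + 1)   # unet_number passed to num_steps_taken -> 2
```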
@@ -321,7 +323,7 @@ def train(
         last_snapshot = sample

         if next_task == 'train':
-            for i, (img, emb, txt) in enumerate(dataloaders["train"]):
+            for i, (img, emb, txt) in enumerate(trainer.train_loader):
                 # We want to count the total number of samples across all processes
                 sample_length_tensor[0] = len(img)
                 all_samples = accelerator.gather(sample_length_tensor)  # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -414,7 +416,7 @@ def train(
     timer = Timer()
     accelerator.wait_for_everyone()
     i = 0
-    for i, (img, emb, txt) in enumerate(dataloaders["val"]):
+    for i, (img, emb, txt) in enumerate(trainer.val_loader):  # Use the accelerate prepared loader
         val_sample_length_tensor[0] = len(img)
         all_samples = accelerator.gather(val_sample_length_tensor)
         total_samples = all_samples.sum().item()
@@ -519,6 +521,20 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     # Set up accelerator for configurable distributed training
     ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
     accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

+    if accelerator.num_processes > 1:
+        # We are using distributed training and want to immediately ensure all can connect
+        accelerator.print("Waiting for all processes to connect...")
+        accelerator.wait_for_everyone()
+        accelerator.print("All processes online and connected")
+
+    # If we are in deepspeed fp16 mode, we must ensure learned variance is off
+    if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
+        raise ValueError("DeepSpeed fp16 mode does not support learned variance")
+
+    if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
+        # This is an invalid configuration until we figure out how to handle this
+        raise ValueError("DeepSpeed does not support multi-node distributed training")
+
     # Set up data
     all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))