@@ -100,6 +100,9 @@ def eval_decorator(fn):
         return out
     return inner

+def is_float_dtype(dtype):
+    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
+
 def is_list_str(x):
     if not isinstance(x, (list, tuple)):
         return False

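Aside (not part of the diff): the new helper just tests membership in torch's floating dtypes. A quick sanity check, with the one-liner copied inline so it runs standalone:

import torch

def is_float_dtype(dtype): # copy of the helper added above
    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])

assert is_float_dtype(torch.float16)     # half precision counts as float
assert not is_float_dtype(torch.int64)   # integer dtypes do not
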
@@ -314,7 +317,10 @@ class OpenAIClipAdapter(BaseClipAdapter):
         self.eos_id = 49407 # for handling 0 being also '!'

         text_attention_final = self.find_layer('ln_final')
+        self.dim_latent_ = text_attention_final.weight.shape[0]
         self.handle = text_attention_final.register_forward_hook(self._hook)
         self.clip_normalize = preprocess.transforms[-1]
         self.cleared = False

@@ -333,7 +339,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
     @property
     def dim_latent(self):
-        return 512
+        return self.dim_latent_

     @property
     def image_size(self):

@@ -354,6 +360,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         is_eos_id = (text == self.eos_id)
         text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
         text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+        text_mask = text_mask & (text != 0)
         assert not self.cleared

         text_embed = self.clip.encode_text(text)

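Aside (a toy illustration, not part of the diff): the cumsum trick keeps every position strictly before the first EOS token, the (1, -1) pad shifts that mask right so the EOS position itself stays included, and the newly added `& (text != 0)` drops zero padding. With made-up token ids apart from CLIP's 49407 EOS id:

import torch
import torch.nn.functional as F

eos_id = 49407
text = torch.tensor([[320, 1125, 49407, 0, 0]])                    # two tokens, EOS, then padding

is_eos_id = (text == eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0          # True strictly before the first EOS
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)  # shift right: EOS itself stays True
text_mask = text_mask & (text != 0)                                # the added line: mask out 0 padding

print(text_mask)  # tensor([[ True,  True,  True, False, False]])
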
@@ -383,6 +390,8 @@ class OpenClipAdapter(BaseClipAdapter):
         self.eos_id = 49407

         text_attention_final = self.find_layer('ln_final')
+        self._dim_latent = text_attention_final.weight.shape[0]
         self.handle = text_attention_final.register_forward_hook(self._hook)
         self.clip_normalize = preprocess.transforms[-1]
         self.cleared = False

@@ -402,11 +411,14 @@ class OpenClipAdapter(BaseClipAdapter):
     @property
     def dim_latent(self):
-        return 512
+        return self._dim_latent

     @property
     def image_size(self):
-        return self.clip.visual.image_size
+        image_size = self.clip.visual.image_size
+        if isinstance(image_size, tuple):
+            return max(image_size)
+        return image_size

     @property
     def image_channels(self):

@@ -423,6 +435,7 @@ class OpenClipAdapter(BaseClipAdapter):
         is_eos_id = (text == self.eos_id)
         text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
         text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+        text_mask = text_mask & (text != 0)
         assert not self.cleared

         text_embed = self.clip.encode_text(text)

@@ -608,7 +621,7 @@ class NoiseScheduler(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def q_sample(self, x_start, t, noise=None):
+    def q_sample(self, x_start, t, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))

         return (

@@ -616,6 +629,12 @@ class NoiseScheduler(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )

+    def calculate_v(self, x_start, t, noise = None):
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
+        )
+
     def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
         shape = x_from.shape
         noise = default(noise, lambda: torch.randn_like(x_from))

@@ -627,6 +646,12 @@ class NoiseScheduler(nn.Module):
         return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha

+    def predict_start_from_v(self, x_t, t, v):
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
     def predict_start_from_noise(self, x_t, t, noise):
         return (
             extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -

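For reference (my note, not in the diff): with a_t = sqrt(alphas_cumprod) and b_t = sqrt(1 - alphas_cumprod), `q_sample` forms x_t = a_t * x0 + b_t * noise, `calculate_v` defines the v-prediction target v = a_t * noise - b_t * x0 (the parameterization from Salimans & Ho's progressive distillation paper), and `predict_start_from_v` inverts it as x0 = a_t * x_t - b_t * v. A scalar sanity check with made-up coefficients satisfying a_t^2 + b_t^2 = 1:

import torch

sqrt_alpha, sqrt_one_minus_alpha = torch.tensor(0.8), torch.tensor(0.6)  # stand-ins for the extract(...) lookups
x_start = torch.randn(4)
noise = torch.randn(4)

x_t = sqrt_alpha * x_start + sqrt_one_minus_alpha * noise   # q_sample
v = sqrt_alpha * noise - sqrt_one_minus_alpha * x_start     # calculate_v
x_recovered = sqrt_alpha * x_t - sqrt_one_minus_alpha * v   # predict_start_from_v

assert torch.allclose(x_recovered, x_start, atol = 1e-6)
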
@@ -962,6 +987,8 @@ class DiffusionPriorNetwork(nn.Module):
             Rearrange('b (n d) -> b n d', n = num_text_embeds)
         )

+        self.continuous_embedded_time = not exists(num_timesteps)
+
         self.to_time_embeds = nn.Sequential(
             nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
             Rearrange('b (n d) -> b n d', n = num_time_embeds)

@@ -1089,12 +1116,15 @@ class DiffusionPriorNetwork(nn.Module):
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right

+        if self.continuous_embedded_time:
+            diffusion_timesteps = diffusion_timesteps.type(dtype)
+
         time_embed = self.to_time_embeds(diffusion_timesteps)

         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)

         if self.self_cond:
-            learned_queries = torch.cat((image_embed, self_cond), dim = -2)
+            learned_queries = torch.cat((self_cond, learned_queries), dim = -2)

         tokens = torch.cat((
             text_encodings,

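On the self-conditioning fix above (my reading of it): the previous prediction is now prepended to the learned query tokens along the sequence dimension rather than being built from `image_embed`, so the causal transformer still ends on its learned query. A minimal shape sketch with made-up sizes:

import torch
from einops import repeat

b, d = 2, 512                                   # hypothetical batch size and embedding dim
learned_query = torch.randn(d)
self_cond = torch.zeros(b, 1, d)                # previous x0 prediction (zeros here for illustration)

learned_queries = repeat(learned_query, 'd -> b 1 d', b = b)
learned_queries = torch.cat((self_cond, learned_queries), dim = -2)  # the corrected concatenation
print(learned_queries.shape)                    # torch.Size([2, 2, 512])
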
@@ -1130,6 +1160,7 @@ class DiffusionPrior(nn.Module):
         image_cond_drop_prob = None,
         loss_type = "l2",
         predict_x_start = True,
+        predict_v = False,
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
         sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)

@@ -1181,6 +1212,7 @@ class DiffusionPrior(nn.Module):
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

         self.predict_x_start = predict_x_start
+        self.predict_v = predict_v # takes precedence over predict_x_start

         # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132

@@ -1210,7 +1242,9 @@ class DiffusionPrior(nn.Module):
         pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)

-        if self.predict_x_start:
+        if self.predict_v:
+            x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+        elif self.predict_x_start:
             x_start = pred
         else:
             x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

@@ -1283,10 +1317,12 @@ class DiffusionPrior(nn.Module):
             # derive x0

-            if self.predict_x_start:
+            if self.predict_v:
+                x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
+            elif self.predict_x_start:
                 x_start = pred
             else:
-                x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise)
+                x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)

             # clip x0 before maybe predicting noise

@@ -1298,10 +1334,7 @@ class DiffusionPrior(nn.Module):
             # predict noise

-            if self.predict_x_start:
-                pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
-            else:
-                pred_noise = pred
+            pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)

             if time_next < 0:
                 image_embed = x_start

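Why the branch collapses to one call (my note): whatever the network predicts, the DDIM update needs a noise estimate that is consistent with the x0 that was just derived and possibly clamped, and epsilon = (x_t - sqrt(alphas_cumprod) * x0) / sqrt(1 - alphas_cumprod) gives exactly that; `predict_noise_from_start` computes the equivalent form with reciprocal terms. A scalar sketch of the identity:

import torch

sqrt_alpha, sqrt_one_minus_alpha = torch.tensor(0.8), torch.tensor(0.6)  # made-up coefficients
x_start = torch.randn(4)
noise = torch.randn(4)
x_t = sqrt_alpha * x_start + sqrt_one_minus_alpha * noise

pred_noise = (x_t - sqrt_alpha * x_start) / sqrt_one_minus_alpha  # recover epsilon from x_t and x0
assert torch.allclose(pred_noise, noise, atol = 1e-6)
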
@@ -1356,7 +1389,12 @@ class DiffusionPrior(nn.Module):
         if self.predict_x_start and self.training_clamp_l2norm:
             pred = self.l2norm_clamp_embed(pred)

-        target = noise if not self.predict_x_start else image_embed
+        if self.predict_v:
+            target = self.noise_scheduler.calculate_v(image_embed, times, noise)
+        elif self.predict_x_start:
+            target = image_embed
+        else:
+            target = noise

         loss = self.noise_scheduler.loss_fn(pred, target)
         return loss

@@ -1426,7 +1464,7 @@ class DiffusionPrior(nn.Module):
         **kwargs
     ):
         assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
-        assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
+        assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
         assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'

         if exists(image):

@@ -1532,6 +1570,8 @@ class SinusoidalPosEmb(nn.Module):
     def forward(self, x):
+        dtype, device = x.dtype, x.device
+        assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'

         half_dim = self.dim // 2
         emb = math.log(10000) / (half_dim - 1)
         emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)

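Context for the new assert (my note): the continuous-time path in `DiffusionPriorNetwork` now casts timesteps to a float dtype before calling this module, and an integer input would silently break the scaling below. A standalone sketch of the usual sinusoidal embedding these lines build up to (the sin/cos concatenation follows the lines shown; `dim` is made up):

import math
import torch

dim = 8
half_dim = dim // 2
freqs = torch.exp(torch.arange(half_dim, dtype = torch.float32) * -(math.log(10000) / (half_dim - 1)))

t = torch.tensor([0.5, 3.0])                 # continuous timesteps, hence the float requirement
emb = t[:, None] * freqs[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
print(emb.shape)                             # torch.Size([2, 8])
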
@@ -2430,6 +2470,7 @@ class Decoder(nn.Module):
         loss_type = 'l2',
         beta_schedule = None,
         predict_x_start = False,
+        predict_v = False,
         predict_x_start_for_latent_diffusion = False,
         image_sizes = None, # for cascading ddpm, image size at each stage
         random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)

@@ -2452,7 +2493,7 @@ class Decoder(nn.Module):
         dynamic_thres_percentile = 0.95,
         p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
         p2_loss_weight_k = 1,
-        ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
+        ddim_sampling_eta = 0. # can be set to 0. for deterministic sampling afaict
     ):
         super().__init__()

@@ -2602,6 +2643,10 @@ class Decoder(nn.Module):
         self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))

+        # predict v
+
+        self.predict_v = cast_tuple(predict_v, len(unets))
+
         # input image range

         self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)

@@ -2682,11 +2727,16 @@ class Decoder(nn.Module):
         if exists(unet_number):
             unet = self.get_unet(unet_number)

+        # devices
+
+        cuda, cpu = torch.device('cuda'), torch.device('cpu')
+
         self.cuda()

         devices = [module_device(unet) for unet in self.unets]

-        self.unets.cpu()
-        unet.cuda()
+        self.unets.to(cpu)
+        unet.to(cuda)

         yield

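A hedged usage sketch (assuming, as elsewhere in the file, that `one_unet_in_gpu` is a contextmanager method on a constructed `Decoder`): only the unet being worked on sits on the GPU while the rest are parked on CPU, and the new `torch.device` objects just make those two moves explicit.

# hypothetical usage on a CUDA machine; `decoder` built elsewhere
with decoder.one_unet_in_gpu(unet_number = 1):
    ...  # run the first unet; the remaining unets stay on CPU for the duration
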
@@ -2713,14 +2763,16 @@ class Decoder(nn.Module):
         x = x.clamp(-s, s) / s
         return x

-    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
+    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

         model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))

         pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)

-        if predict_x_start:
+        if predict_v:
+            x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+        elif predict_x_start:
             x_start = pred
         else:
             x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

@@ -2747,9 +2799,9 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance, x_start

     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
+    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

@@ -2764,6 +2816,7 @@ class Decoder(nn.Module):
         image_embed,
         noise_scheduler,
         predict_x_start = False,
+        predict_v = False,
         learned_variance = False,
         clip_denoised = True,
         lowres_cond_img = None,

@@ -2822,6 +2875,7 @@ class Decoder(nn.Module):
                 lowres_cond_img = lowres_cond_img,
                 lowres_noise_level = lowres_noise_level,
                 predict_x_start = predict_x_start,
+                predict_v = predict_v,
                 noise_scheduler = noise_scheduler,
                 learned_variance = learned_variance,
                 clip_denoised = clip_denoised

@@ -2847,6 +2901,7 @@ class Decoder(nn.Module):
         timesteps,
         eta = 1.,
         predict_x_start = False,
+        predict_v = False,
         learned_variance = False,
         clip_denoised = True,
         lowres_cond_img = None,

@@ -2908,7 +2963,9 @@ class Decoder(nn.Module):
             # predict x0

-            if predict_x_start:
+            if predict_v:
+                x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
+            elif predict_x_start:
                 x_start = pred
             else:
                 x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)

@@ -2920,10 +2977,7 @@ class Decoder(nn.Module):
             # predict noise

-            if predict_x_start:
-                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
-            else:
-                pred_noise = pred
+            pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)

             c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
             c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()

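For reference (my note): `c1` is the DDIM noise scale sigma_t, so the new `ddim_sampling_eta = 0.` constructor default makes the update deterministic, while `eta = 1.` recovers DDPM-like stochasticity; `c2` then keeps the total variance at 1 - alpha_next. A quick check with made-up cumulative alphas:

import torch

eta = 0.                                                    # the new default
alpha, alpha_next = torch.tensor(0.9), torch.tensor(0.95)   # made-up alphas_cumprod for adjacent steps

c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()

assert c1 == 0                                      # no fresh noise injected when eta = 0
assert torch.isclose(c2, (1 - alpha_next).sqrt())
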
@@ -2957,7 +3011,7 @@ class Decoder(nn.Module):
         return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)

-    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
+    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
         noise = default(noise, lambda: torch.randn_like(x_start))

         # normalize to [-1, 1]

@@ -3002,7 +3056,12 @@ class Decoder(nn.Module):
         pred, _ = self.parse_unet_output(learned_variance, unet_output)

-        target = noise if not predict_x_start else x_start
+        if predict_v:
+            target = noise_scheduler.calculate_v(x_start, times, noise)
+        elif predict_x_start:
+            target = x_start
+        else:
+            target = noise

         loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
         loss = reduce(loss, 'b ... -> b (...)', 'mean')

@@ -3060,7 +3119,8 @@ class Decoder(nn.Module):
         distributed = False,
         inpaint_image = None,
         inpaint_mask = None,
-        inpaint_resample_times = 5
+        inpaint_resample_times = 5,
+        one_unet_in_gpu_at_time = True
     ):
         assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'

@@ -3083,16 +3143,17 @@ class Decoder(nn.Module):
             assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
             prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
             img = resize_image_to(image, prev_unet_output_size, nearest = True)

         is_cuda = next(self.parameters()).is_cuda

         num_unets = self.num_unets
         cond_scale = cast_tuple(cond_scale, num_unets)

-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
             if unet_number < start_at_unet_number:
                 continue # It's the easiest way to do it

-            context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
+            context = self.one_unet_in_gpu(unet = unet) if is_cuda and one_unet_in_gpu_at_time else null_context()

             with context:
                 # prepare low resolution conditioning for upsamplers

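A hedged usage sketch of the new flag (argument names taken from the signature above, everything else hypothetical): passing `one_unet_in_gpu_at_time = False` skips the per-unet CPU/GPU shuffling, which is preferable when the whole cascade already fits in memory or when sampling on CPU.

# hypothetical call; `decoder` and `image_embed` prepared elsewhere
images = decoder.sample(
    image_embed = image_embed,
    cond_scale = 2.,
    one_unet_in_gpu_at_time = False  # keep every unet on its current device
)
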
@@ -3124,6 +3185,7 @@ class Decoder(nn.Module):
                     text_encodings = text_encodings,
                     cond_scale = unet_cond_scale,
                     predict_x_start = predict_x_start,
+                    predict_v = predict_v,
                     learned_variance = learned_variance,
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,

@@ -3163,6 +3225,7 @@ class Decoder(nn.Module):
         lowres_conditioner = self.lowres_conds[unet_index]
         target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
+        predict_v = self.predict_v[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
         learned_variance = self.learned_variance[unet_index]
         b, c, h, w, device, = *image.shape, image.device

@@ -3201,7 +3264,7 @@ class Decoder(nn.Module):
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)

-        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
+        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)

         if not return_lowres_cond_image:
             return losses