Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
allow the beta schedule of each unet in the decoder to be set differently, since the schedules used in the paper were cosine, cosine, linear
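A minimal usage sketch of the resulting interface (hypothetical: `unet1`, `unet2`, `unet3`, and `clip` are assumed to be constructed elsewhere, and the remaining constructor arguments are elided):

    decoder = Decoder(
        unet = (unet1, unet2, unet3),                    # three-stage cascade, as in the paper
        clip = clip,
        timesteps = 1000,
        image_sizes = (64, 256, 1024),
        beta_schedule = ('cosine', 'cosine', 'linear')   # one noise schedule per unet
    )

Leaving `beta_schedule = None` (the new default) derives the same cosine, cosine, linear assignment automatically from the number of unets.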
@@ -378,7 +378,7 @@ def sigmoid_beta_schedule(timesteps):
     return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
 
 
-class BaseGaussianDiffusion(nn.Module):
+class NoiseScheduler(nn.Module):
     def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
         super().__init__()
 
@@ -472,11 +472,10 @@ class BaseGaussianDiffusion(nn.Module):
             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
         )
 
-    def sample(self, *args, **kwargs):
-        raise NotImplementedError
-
-    def forward(self, *args, **kwargs):
-        raise NotImplementedError
+    def p2_reweigh_loss(self, loss, times):
+        if not self.has_p2_loss_reweighting:
+            return loss
+        return loss * extract(self.p2_loss_weight, times, loss.shape)
 
 # diffusion prior
 
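The removed `sample`/`forward` stubs are superseded now that `DiffusionPrior` and `Decoder` subclass `nn.Module` directly; the new `p2_reweigh_loss` helper multiplies the per-timestep loss by the precomputed `p2_loss_weight` buffer. For reference, a sketch of how that weight is defined in https://arxiv.org/abs/2204.00227 (the function below is illustrative, not part of this diff):

    import torch

    # weight_t = (k + SNR_t) ** -gamma, where SNR_t = alphas_cumprod_t / (1 - alphas_cumprod_t)
    def p2_loss_weight(alphas_cumprod, gamma = 1., k = 1.):
        snr = alphas_cumprod / (1 - alphas_cumprod)
        return (k + snr) ** -gamma

    weights = p2_loss_weight(torch.linspace(0.9999, 0.0001, 1000))  # stand-in schedule

With gamma = 0 every weight equals 1, matching the "0 is equivalent to weight of 1 across time" note on the constructor argument further down.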
@@ -687,8 +686,7 @@ class Attention(nn.Module):
 
         # attention
 
-        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
-        attn = sim.softmax(dim = -1)
+        attn = sim.softmax(dim = -1, dtype = torch.float32)
         attn = self.dropout(attn)
 
         # aggregate values
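Both the removed max-subtraction and the added `dtype = torch.float32` protect the softmax from overflow when the attention logits are in half precision; casting inside the softmax makes the subtraction unnecessary. A self-contained check (illustrative only, not from this commit):

    import torch

    sim = (torch.randn(1, 8, 16, 16) * 50).half()

    # old guard: subtract the rowwise max so exp never overflows in fp16
    old = (sim - sim.amax(dim = -1, keepdim = True).detach()).softmax(dim = -1)

    # new guard: cast to float32 inside the softmax instead
    new = sim.softmax(dim = -1, dtype = torch.float32)

    assert torch.allclose(old.float(), new, atol = 1e-2)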
@@ -862,7 +860,7 @@ class DiffusionPriorNetwork(nn.Module):
 
         return pred_image_embed
 
-class DiffusionPrior(BaseGaussianDiffusion):
+class DiffusionPrior(nn.Module):
     def __init__(
         self,
         net,
@@ -883,7 +881,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
         image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
         clip_adapter_overrides = dict()
     ):
-        super().__init__(
+        super().__init__()
+
+        self.noise_scheduler = NoiseScheduler(
             beta_schedule = beta_schedule,
             timesteps = timesteps,
             loss_type = loss_type
@@ -923,6 +923,13 @@ class DiffusionPrior(BaseGaussianDiffusion):
         self.training_clamp_l2norm = training_clamp_l2norm
         self.init_image_embed_l2norm = init_image_embed_l2norm
 
+        # device tracker
+        self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
+
+    @property
+    def device(self):
+        return self._dummy.device
+
     def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
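Since `DiffusionPrior` no longer inherits registered buffers such as `betas`, the non-persistent `_dummy` buffer above is what keeps a handle on the module's device. The pattern in isolation (a minimal sketch, not repo code):

    import torch
    from torch import nn

    class DeviceTracked(nn.Module):
        def __init__(self):
            super().__init__()
            # follows .to() / .cuda() moves, but persistent = False keeps it out of state dicts
            self.register_buffer('_dummy', torch.tensor([True]), persistent = False)

        @property
        def device(self):
            return self._dummy.device

    module = DeviceTracked()
    print(module.device)  # cpu, until the module is moved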
@@ -933,7 +940,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
         # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
         else:
-            x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
+            x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
 
         if clip_denoised and not self.predict_x_start:
             x_recon.clamp_(-1., 1.)
@@ -941,7 +948,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         if self.predict_x_start and self.sampling_clamp_l2norm:
             x_recon = l2norm(x_recon) * self.image_embed_scale
 
-        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
@@ -955,7 +962,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
 
     @torch.no_grad()
     def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
-        device = self.betas.device
+        device = self.device
 
         b = shape[0]
         image_embed = torch.randn(shape, device=device)
@@ -963,7 +970,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
-        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
+        for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
             times = torch.full((b,), i, device = device, dtype = torch.long)
             image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
 
@@ -972,7 +979,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
     def p_losses(self, image_embed, times, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))
 
-        image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
+        image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
 
         pred = self.net(
             image_embed_noisy,
@@ -986,7 +993,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
 
         target = noise if not self.predict_x_start else image_embed
 
-        loss = self.loss_fn(pred, target)
+        loss = self.noise_scheduler.loss_fn(pred, target)
         return loss
 
     @torch.no_grad()
@@ -997,7 +1004,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
 
         img = torch.randn(shape, device = device)
 
-        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
+        for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = self.noise_scheduler.num_timesteps):
             img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
         return img
 
@@ -1069,7 +1076,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         # timestep conditioning from ddpm
 
         batch, device = image_embed.shape[0], image_embed.device
-        times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
+        times = torch.randint(0, self.noise_scheduler.num_timesteps, (batch,), device = device, dtype = torch.long)
 
         # scale image embed (Katherine)
 
@@ -1234,8 +1241,7 @@ class CrossAttention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)
 
-        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
-        attn = sim.softmax(dim = -1)
+        attn = sim.softmax(dim = -1, dtype = torch.float32)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1739,7 +1745,7 @@ class LowresConditioner(nn.Module):
 
         return cond_fmap
 
-class Decoder(BaseGaussianDiffusion):
+class Decoder(nn.Module):
     def __init__(
         self,
         unet,
@@ -1752,7 +1758,7 @@ class Decoder(BaseGaussianDiffusion):
         image_cond_drop_prob = 0.1,
         text_cond_drop_prob = 0.5,
         loss_type = 'l2',
-        beta_schedule = 'cosine',
+        beta_schedule = None,
         predict_x_start = False,
         predict_x_start_for_latent_diffusion = False,
         image_sizes = None, # for cascading ddpm, image size at each stage
@@ -1774,13 +1780,7 @@ class Decoder(BaseGaussianDiffusion):
         p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
         p2_loss_weight_k = 1
     ):
-        super().__init__(
-            beta_schedule = beta_schedule,
-            timesteps = timesteps,
-            loss_type = loss_type,
-            p2_loss_weight_gamma = p2_loss_weight_gamma,
-            p2_loss_weight_k = p2_loss_weight_k
-        )
+        super().__init__()
 
         self.unconditional = unconditional
 
@@ -1824,6 +1824,8 @@ class Decoder(BaseGaussianDiffusion):
         # while the rest of the unets are conditioned on the low resolution image produced by previous unet
 
         unets = cast_tuple(unet)
+        num_unets = len(unets)
+
         vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
 
         # whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
@@ -1859,6 +1861,24 @@ class Decoder(BaseGaussianDiffusion):
             self.unets.append(one_unet)
             self.vaes.append(one_vae.copy_for_eval())
 
+        # create noise schedulers per unet
+
+        if not exists(beta_schedule):
+            beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))
+
+        self.noise_schedulers = nn.ModuleList([])
+
+        for unet_beta_schedule in beta_schedule:
+            noise_scheduler = NoiseScheduler(
+                beta_schedule = unet_beta_schedule,
+                timesteps = timesteps,
+                loss_type = loss_type,
+                p2_loss_weight_gamma = p2_loss_weight_gamma,
+                p2_loss_weight_k = p2_loss_weight_k
+            )
+
+            self.noise_schedulers.append(noise_scheduler)
+
         # unet image sizes
 
         image_sizes = default(image_sizes, (image_size,))
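For reference, the default tuple built above expands as follows for different cascade depths (a standalone check, not part of the diff):

    def default_beta_schedules(num_unets):
        return ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))

    assert default_beta_schedules(1) == ('cosine',)
    assert default_beta_schedules(2) == ('cosine', 'linear')
    assert default_beta_schedules(3) == ('cosine', 'cosine', 'linear')  # the paper's setting

Note that the loop also implies a manually supplied `beta_schedule` must be a tuple with one entry per unet, not a bare string.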
@@ -1908,6 +1928,14 @@ class Decoder(BaseGaussianDiffusion):
         self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
         self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
 
+        # device tracker
+
+        self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)
+
+    @property
+    def device(self):
+        return self._dummy.device
+
     def get_unet(self, unet_number):
         assert 0 < unet_number <= len(self.unets)
         index = unet_number - 1
@@ -1931,7 +1959,7 @@ class Decoder(BaseGaussianDiffusion):
         for unet, device in zip(self.unets, devices):
             unet.to(device)
 
-    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
+    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
         pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
@@ -1942,7 +1970,7 @@ class Decoder(BaseGaussianDiffusion):
         if predict_x_start:
             x_recon = pred
         else:
-            x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
+            x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
 
         if clip_denoised:
             # s is the threshold amount
@@ -1961,14 +1989,14 @@ class Decoder(BaseGaussianDiffusion):
             # clip by threshold, depending on whether static or dynamic
             x_recon = x_recon.clamp(-s, s) / s
 
-        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
 
         if learned_variance:
             # if learned variance, posterior variance and posterior log variance are predicted by the network
             # by an interpolation of the max and min log beta values
             # eq 15 - https://arxiv.org/abs/2102.09672
-            min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
-            max_log = extract(torch.log(self.betas), t, x.shape)
+            min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)
+            max_log = extract(torch.log(noise_scheduler.betas), t, x.shape)
             var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
 
             if self.learned_variance_constrain_frac:
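The `min_log` / `max_log` bounds feed eq. 15 of https://arxiv.org/abs/2102.09672: the network predicts a fraction that interpolates the log variance between the clipped-posterior and beta extremes. In isolation, with illustrative stand-in values (a sketch of the interpolation those bounds are used for):

    import torch

    min_log = torch.log(torch.tensor(1e-4))  # posterior_log_variance_clipped at some t
    max_log = torch.log(torch.tensor(2e-2))  # log beta at the same t
    var_interp_frac = torch.tensor(0.3)      # network output, unnormalized from [-1, 1] to [0, 1]

    model_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
    model_variance = model_log_variance.exp()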
@@ -1980,17 +2008,17 @@ class Decoder(BaseGaussianDiffusion):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
+    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
-        device = self.betas.device
+    def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+        device = self.device
 
         b = shape[0]
         img = torch.randn(shape, device = device)
@@ -1998,7 +2026,7 @@ class Decoder(BaseGaussianDiffusion):
         if not is_latent_diffusion:
             lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
 
-        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
+        for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
             img = self.p_sample(
                 unet,
                 img,
@@ -2009,6 +2037,7 @@ class Decoder(BaseGaussianDiffusion):
                 cond_scale = cond_scale,
                 lowres_cond_img = lowres_cond_img,
                 predict_x_start = predict_x_start,
+                noise_scheduler = noise_scheduler,
                 learned_variance = learned_variance,
                 clip_denoised = clip_denoised
             )
@@ -2016,7 +2045,7 @@ class Decoder(BaseGaussianDiffusion):
         unnormalize_img = self.unnormalize_img(img)
         return unnormalize_img
 
-    def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
+    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         # normalize to [-1, 1]
@@ -2027,7 +2056,7 @@ class Decoder(BaseGaussianDiffusion):
 
         # get x_t
 
-        x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
+        x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
 
         model_output = unet(
             x_noisy,
@@ -2047,11 +2076,10 @@ class Decoder(BaseGaussianDiffusion):
 
         target = noise if not predict_x_start else x_start
 
-        loss = self.loss_fn(pred, target, reduction = 'none')
+        loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
         loss = reduce(loss, 'b ... -> b (...)', 'mean')
 
-        if self.has_p2_loss_reweighting:
-            loss = loss * extract(self.p2_loss_weight, times, loss.shape)
+        loss = noise_scheduler.p2_reweigh_loss(loss, times)
 
         loss = loss.mean()
 
@@ -2066,8 +2094,8 @@ class Decoder(BaseGaussianDiffusion):
 
         # if learning the variance, also include the extra weight kl loss
 
-        true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+        true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
 
         # kl loss with detached model predicted mean, for stability reasons as in paper
 
@@ -2117,7 +2145,7 @@ class Decoder(BaseGaussianDiffusion):
         img = None
         is_cuda = next(self.parameters()).is_cuda
 
-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
 
             context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
 
@@ -2145,7 +2173,8 @@ class Decoder(BaseGaussianDiffusion):
                     learned_variance = learned_variance,
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,
-                    is_latent_diffusion = is_latent_diffusion
+                    is_latent_diffusion = is_latent_diffusion,
+                    noise_scheduler = noise_scheduler
                 )
 
                 img = vae.decode(img)
@@ -2171,6 +2200,7 @@ class Decoder(BaseGaussianDiffusion):
         unet = self.get_unet(unet_number)
 
         vae = self.vaes[unet_index]
+        noise_scheduler = self.noise_schedulers[unet_index]
        target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
@@ -2180,7 +2210,7 @@ class Decoder(BaseGaussianDiffusion):
         check_shape(image, 'b c h w', c = self.channels)
         assert h >= target_image_size and w >= target_image_size
 
-        times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
+        times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
 
         if not exists(image_embed) and not self.unconditional:
             assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
@@ -2211,7 +2241,7 @@ class Decoder(BaseGaussianDiffusion):
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
 
-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
+        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
 
 # main class
 