mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 01:34:19 +01:00
complete imagen-like noise level conditioning
This commit is contained in:
@@ -52,10 +52,10 @@ def first(arr, d = None):
|
||||
|
||||
def maybe(fn):
|
||||
@wraps(fn)
|
||||
def inner(x):
|
||||
def inner(x, *args, **kwargs):
|
||||
if not exists(x):
|
||||
return x
|
||||
return fn(x)
|
||||
return fn(x, *args, **kwargs)
|
||||
return inner
|
||||
|
||||
def default(val, d):
|
||||
@@ -63,13 +63,13 @@ def default(val, d):
|
||||
return val
|
||||
return d() if callable(d) else d
|
||||
|
||||
def cast_tuple(val, length = None):
|
||||
def cast_tuple(val, length = None, validate = True):
|
||||
if isinstance(val, list):
|
||||
val = tuple(val)
|
||||
|
||||
out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
|
||||
|
||||
if exists(length):
|
||||
if exists(length) and validate:
|
||||
assert len(out) == length
|
||||
|
||||
return out
|
||||
@@ -494,6 +494,9 @@ class NoiseScheduler(nn.Module):
|
||||
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
|
||||
register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
|
||||
|
||||
def sample_random_times(self, batch):
|
||||
return torch.randint(0, self.num_timesteps, (batch,), device = self.betas.device, dtype = torch.long)
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||
@@ -1243,7 +1246,7 @@ class DiffusionPrior(nn.Module):
|
||||
# timestep conditioning from ddpm
|
||||
|
||||
batch, device = image_embed.shape[0], image_embed.device
|
||||
times = torch.randint(0, self.noise_scheduler.num_timesteps, (batch,), device = device, dtype = torch.long)
|
||||
times = self.noise_scheduler.sample_random_times(batch)
|
||||
|
||||
# scale image embed (Katherine)
|
||||
|
||||
@@ -1540,6 +1543,7 @@ class Unet(nn.Module):
|
||||
attn_dim_head = 32,
|
||||
attn_heads = 16,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
|
||||
sparse_attn = False,
|
||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||
cond_on_text_encodings = False,
|
||||
@@ -1629,6 +1633,17 @@ class Unet(nn.Module):
|
||||
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
|
||||
self.text_embed_dim = text_embed_dim
|
||||
|
||||
# low resolution noise conditiong, based on Imagen's upsampler training technique
|
||||
|
||||
self.lowres_noise_cond = lowres_noise_cond
|
||||
|
||||
self.to_lowres_noise_cond = nn.Sequential(
|
||||
SinusoidalPosEmb(dim),
|
||||
nn.Linear(dim, time_cond_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(time_cond_dim, time_cond_dim)
|
||||
) if lowres_noise_cond else None
|
||||
|
||||
# finer control over whether to condition on image embeddings and text encodings
|
||||
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
||||
|
||||
@@ -1745,15 +1760,17 @@ class Unet(nn.Module):
|
||||
self,
|
||||
*,
|
||||
lowres_cond,
|
||||
lowres_noise_cond,
|
||||
channels,
|
||||
channels_out,
|
||||
cond_on_image_embeds,
|
||||
cond_on_text_encodings
|
||||
cond_on_text_encodings,
|
||||
):
|
||||
if lowres_cond == self.lowres_cond and \
|
||||
channels == self.channels and \
|
||||
cond_on_image_embeds == self.cond_on_image_embeds and \
|
||||
cond_on_text_encodings == self.cond_on_text_encodings and \
|
||||
cond_on_lowres_noise == self.cond_on_lowres_noise and \
|
||||
channels_out == self.channels_out:
|
||||
return self
|
||||
|
||||
@@ -1762,7 +1779,8 @@ class Unet(nn.Module):
|
||||
channels = channels,
|
||||
channels_out = channels_out,
|
||||
cond_on_image_embeds = cond_on_image_embeds,
|
||||
cond_on_text_encodings = cond_on_text_encodings
|
||||
cond_on_text_encodings = cond_on_text_encodings,
|
||||
lowres_noise_cond = lowres_noise_cond
|
||||
)
|
||||
|
||||
return self.__class__(**{**self._locals, **updated_kwargs})
|
||||
@@ -1788,6 +1806,7 @@ class Unet(nn.Module):
|
||||
*,
|
||||
image_embed,
|
||||
lowres_cond_img = None,
|
||||
lowres_noise_level = None,
|
||||
text_encodings = None,
|
||||
image_cond_drop_prob = 0.,
|
||||
text_cond_drop_prob = 0.,
|
||||
@@ -1816,6 +1835,13 @@ class Unet(nn.Module):
|
||||
time_tokens = self.to_time_tokens(time_hiddens)
|
||||
t = self.to_time_cond(time_hiddens)
|
||||
|
||||
# low res noise conditioning (similar to time above)
|
||||
|
||||
if exists(lowres_noise_level):
|
||||
assert exists(self.to_lowres_noise_cond), 'lowres_noise_cond must be set to True on instantiation of the unet in order to conditiong on lowres noise'
|
||||
lowres_noise_level = lowres_noise_level.type_as(x)
|
||||
t = t + self.to_lowres_noise_cond(lowres_noise_level)
|
||||
|
||||
# conditional dropout
|
||||
|
||||
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
||||
@@ -1965,25 +1991,48 @@ class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
downsample_first = True,
|
||||
use_blur = True,
|
||||
blur_prob = 0.5,
|
||||
blur_sigma = 0.6,
|
||||
blur_kernel_size = 3,
|
||||
input_image_range = None
|
||||
use_noise = False,
|
||||
input_image_range = None,
|
||||
normalize_img_fn = identity,
|
||||
unnormalize_img_fn = identity
|
||||
):
|
||||
super().__init__()
|
||||
self.downsample_first = downsample_first
|
||||
self.input_image_range = input_image_range
|
||||
|
||||
self.use_blur = use_blur
|
||||
self.blur_prob = blur_prob
|
||||
self.blur_sigma = blur_sigma
|
||||
self.blur_kernel_size = blur_kernel_size
|
||||
|
||||
self.use_noise = use_noise
|
||||
self.normalize_img = normalize_img_fn
|
||||
self.unnormalize_img = unnormalize_img_fn
|
||||
self.noise_scheduler = NoiseScheduler(beta_schedule = 'linear', timesteps = 1000, loss_type = 'l2') if use_noise else None
|
||||
|
||||
def noise_image(self, cond_fmap, noise_levels = None):
|
||||
assert exists(self.noise_scheduler)
|
||||
|
||||
batch = cond_fmap.shape[0]
|
||||
cond_fmap = self.normalize_img(cond_fmap)
|
||||
|
||||
random_noise_levels = default(noise_levels, lambda: self.noise_scheduler.sample_random_times(batch))
|
||||
cond_fmap = self.noise_scheduler.q_sample(cond_fmap, t = random_noise_levels, noise = torch.randn_like(cond_fmap))
|
||||
|
||||
cond_fmap = self.unnormalize_img(cond_fmap)
|
||||
return cond_fmap, random_noise_levels
|
||||
|
||||
def forward(
|
||||
self,
|
||||
cond_fmap,
|
||||
*,
|
||||
target_image_size,
|
||||
downsample_image_size = None,
|
||||
should_blur = True,
|
||||
blur_sigma = None,
|
||||
blur_kernel_size = None
|
||||
):
|
||||
@@ -1993,7 +2042,7 @@ class LowresConditioner(nn.Module):
|
||||
# blur is only applied 50% of the time
|
||||
# section 3.1 in https://arxiv.org/abs/2106.15282
|
||||
|
||||
if random.random() < self.blur_prob:
|
||||
if self.use_blur and should_blur and random.random() < self.blur_prob:
|
||||
|
||||
# when training, blur the low resolution conditional image
|
||||
|
||||
@@ -2015,8 +2064,21 @@ class LowresConditioner(nn.Module):
|
||||
|
||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
# resize to target image size
|
||||
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range, nearest = True)
|
||||
return cond_fmap
|
||||
|
||||
# noise conditioning, as done in Imagen
|
||||
# as a replacement for the BSR noising, and potentially replace blurring for first stage too
|
||||
|
||||
random_noise_levels = None
|
||||
|
||||
if self.use_noise:
|
||||
cond_fmap, random_noise_levels = self.noise_image(cond_fmap)
|
||||
|
||||
# return conditioning feature map, as well as the augmentation noise levels
|
||||
|
||||
return cond_fmap, random_noise_levels
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
@@ -2037,10 +2099,13 @@ class Decoder(nn.Module):
|
||||
predict_x_start_for_latent_diffusion = False,
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||
use_noise_for_lowres_cond = False, # whether to use Imagen-like noising for low resolution conditioning
|
||||
use_blur_for_lowres_cond = True, # whether to use the blur conditioning used in the original cascading ddpm paper, as well as DALL-E2
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
|
||||
blur_sigma = 0.6, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
lowres_noise_sample_level = 0.2, # in imagen paper, they use a 0.2 noise level at sample time for low resolution conditioning
|
||||
clip_denoised = True,
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict(),
|
||||
@@ -2088,10 +2153,17 @@ class Decoder(nn.Module):
|
||||
|
||||
self.channels = channels
|
||||
|
||||
|
||||
# normalize and unnormalize image functions
|
||||
|
||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
||||
|
||||
# verify conditioning method
|
||||
|
||||
unets = cast_tuple(unet)
|
||||
num_unets = len(unets)
|
||||
self.num_unets = num_unets
|
||||
|
||||
self.unconditional = unconditional
|
||||
|
||||
@@ -2107,12 +2179,28 @@ class Decoder(nn.Module):
|
||||
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
|
||||
self.vb_loss_weight = vb_loss_weight
|
||||
|
||||
# default and validate conditioning parameters
|
||||
|
||||
use_noise_for_lowres_cond = cast_tuple(use_noise_for_lowres_cond, num_unets - 1, validate = False)
|
||||
use_blur_for_lowres_cond = cast_tuple(use_blur_for_lowres_cond, num_unets - 1, validate = False)
|
||||
|
||||
if len(use_noise_for_lowres_cond) < num_unets:
|
||||
use_noise_for_lowres_cond = (False, *use_noise_for_lowres_cond)
|
||||
|
||||
if len(use_blur_for_lowres_cond) < num_unets:
|
||||
use_blur_for_lowres_cond = (False, *use_blur_for_lowres_cond)
|
||||
|
||||
assert not use_noise_for_lowres_cond[0], 'first unet will never need low res noise conditioning'
|
||||
assert not use_blur_for_lowres_cond[0], 'first unet will never need low res blur conditioning'
|
||||
|
||||
assert num_unets == 1 or all((use_noise or use_blur) for use_noise, use_blur in zip(use_noise_for_lowres_cond[1:], use_blur_for_lowres_cond[1:]))
|
||||
|
||||
# construct unets and vaes
|
||||
|
||||
self.unets = nn.ModuleList([])
|
||||
self.vaes = nn.ModuleList([])
|
||||
|
||||
for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
|
||||
for ind, (one_unet, one_vae, one_unet_learned_var, lowres_noise_cond) in enumerate(zip(unets, vaes, learned_variance, use_noise_for_lowres_cond)):
|
||||
assert isinstance(one_unet, Unet)
|
||||
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
|
||||
|
||||
@@ -2124,6 +2212,7 @@ class Decoder(nn.Module):
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
lowres_noise_cond = lowres_noise_cond,
|
||||
cond_on_image_embeds = not unconditional and is_first,
|
||||
cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings,
|
||||
channels = unet_channels,
|
||||
@@ -2166,7 +2255,7 @@ class Decoder(nn.Module):
|
||||
image_sizes = default(image_sizes, (image_size,))
|
||||
image_sizes = tuple(sorted(set(image_sizes)))
|
||||
|
||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||
assert self.num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({self.num_unets}) for resolutions {image_sizes}'
|
||||
self.image_sizes = image_sizes
|
||||
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
|
||||
|
||||
@@ -2186,15 +2275,30 @@ class Decoder(nn.Module):
|
||||
# cascading ddpm related stuff
|
||||
|
||||
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
|
||||
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
||||
assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
||||
|
||||
self.to_lowres_cond = LowresConditioner(
|
||||
downsample_first = lowres_downsample_first,
|
||||
blur_prob = blur_prob,
|
||||
blur_sigma = blur_sigma,
|
||||
blur_kernel_size = blur_kernel_size,
|
||||
input_image_range = self.input_image_range
|
||||
)
|
||||
self.lowres_conds = nn.ModuleList([])
|
||||
|
||||
for unet_index, use_noise, use_blur in zip(range(num_unets), use_noise_for_lowres_cond, use_blur_for_lowres_cond):
|
||||
if unet_index == 0:
|
||||
self.lowres_conds.append(None)
|
||||
continue
|
||||
|
||||
lowres_cond = LowresConditioner(
|
||||
downsample_first = lowres_downsample_first,
|
||||
use_blur = use_blur,
|
||||
use_noise = use_noise,
|
||||
blur_prob = blur_prob,
|
||||
blur_sigma = blur_sigma,
|
||||
blur_kernel_size = blur_kernel_size,
|
||||
input_image_range = self.input_image_range,
|
||||
normalize_img_fn = self.normalize_img,
|
||||
unnormalize_img_fn = self.unnormalize_img
|
||||
)
|
||||
|
||||
self.lowres_conds.append(lowres_cond)
|
||||
|
||||
self.lowres_noise_sample_level = lowres_noise_sample_level
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
@@ -2212,11 +2316,6 @@ class Decoder(nn.Module):
|
||||
self.use_dynamic_thres = use_dynamic_thres
|
||||
self.dynamic_thres_percentile = dynamic_thres_percentile
|
||||
|
||||
# normalize and unnormalize image functions
|
||||
|
||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
||||
|
||||
# device tracker
|
||||
|
||||
self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)
|
||||
@@ -2230,7 +2329,7 @@ class Decoder(nn.Module):
|
||||
return any([unet.cond_on_text_encodings for unet in self.unets])
|
||||
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= len(self.unets)
|
||||
assert 0 < unet_number <= self.num_unets
|
||||
index = unet_number - 1
|
||||
return self.unets[index]
|
||||
|
||||
@@ -2316,7 +2415,7 @@ class Decoder(nn.Module):
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
|
||||
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
|
||||
device = self.device
|
||||
|
||||
b = shape[0]
|
||||
@@ -2334,6 +2433,7 @@ class Decoder(nn.Module):
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
lowres_noise_level = lowres_noise_level,
|
||||
predict_x_start = predict_x_start,
|
||||
noise_scheduler = noise_scheduler,
|
||||
learned_variance = learned_variance,
|
||||
@@ -2344,7 +2444,7 @@ class Decoder(nn.Module):
|
||||
return unnormalize_img
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
|
||||
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
|
||||
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
|
||||
|
||||
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
|
||||
@@ -2363,7 +2463,7 @@ class Decoder(nn.Module):
|
||||
|
||||
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
||||
|
||||
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
||||
|
||||
if learned_variance:
|
||||
pred, _ = pred.chunk(2, dim = 1)
|
||||
@@ -2402,7 +2502,7 @@ class Decoder(nn.Module):
|
||||
|
||||
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
|
||||
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
# normalize to [-1, 1]
|
||||
@@ -2421,6 +2521,7 @@ class Decoder(nn.Module):
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
lowres_noise_level = lowres_noise_level,
|
||||
image_cond_drop_prob = self.image_cond_drop_prob,
|
||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||
)
|
||||
@@ -2500,20 +2601,24 @@ class Decoder(nn.Module):
|
||||
img = None
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
|
||||
num_unets = len(self.unets)
|
||||
num_unets = self.num_unets
|
||||
cond_scale = cast_tuple(cond_scale, num_unets)
|
||||
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps, cond_scale)):
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
|
||||
|
||||
with context:
|
||||
lowres_cond_img = None
|
||||
lowres_cond_img = lowres_noise_level = None
|
||||
shape = (batch_size, channel, image_size, image_size)
|
||||
|
||||
if unet.lowres_cond:
|
||||
lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = True)
|
||||
|
||||
if lowres_cond.use_noise:
|
||||
lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)
|
||||
lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)
|
||||
|
||||
is_latent_diffusion = isinstance(vae, VQGanVAE)
|
||||
image_size = vae.get_encoded_fmap_size(image_size)
|
||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||
@@ -2530,6 +2635,7 @@ class Decoder(nn.Module):
|
||||
learned_variance = learned_variance,
|
||||
clip_denoised = not is_latent_diffusion,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
lowres_noise_level = lowres_noise_level,
|
||||
is_latent_diffusion = is_latent_diffusion,
|
||||
noise_scheduler = noise_scheduler,
|
||||
timesteps = sample_timesteps
|
||||
@@ -2551,7 +2657,7 @@ class Decoder(nn.Module):
|
||||
unet_number = None,
|
||||
return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
|
||||
):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
assert not (self.num_unets > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {self.num_unets}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
unet_index = unet_number - 1
|
||||
|
||||
@@ -2559,6 +2665,7 @@ class Decoder(nn.Module):
|
||||
|
||||
vae = self.vaes[unet_index]
|
||||
noise_scheduler = self.noise_schedulers[unet_index]
|
||||
lowres_conditioner = self.lowres_conds[unet_index]
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
predict_x_start = self.predict_x_start[unet_index]
|
||||
random_crop_size = self.random_crop_sizes[unet_index]
|
||||
@@ -2581,7 +2688,7 @@ class Decoder(nn.Module):
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
lowres_cond_img, lowres_noise_level = lowres_conditioner(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if exists(lowres_conditioner) else (None, None)
|
||||
image = resize_image_to(image, target_image_size, nearest = True)
|
||||
|
||||
if exists(random_crop_size):
|
||||
@@ -2599,7 +2706,7 @@ class Decoder(nn.Module):
|
||||
image = vae.encode(image)
|
||||
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||
|
||||
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
||||
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
|
||||
|
||||
if not return_lowres_cond_image:
|
||||
return losses
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.24.3'
|
||||
__version__ = '0.25.0'
|
||||
|
||||
Reference in New Issue
Block a user