complete imagen-like noise level conditioning

Author: Phil Wang
Date: 2022-07-18 13:43:57 -07:00
Commit: 6afb886cf4 (parent c7fe4f2f44)

2 changed files with 144 additions and 37 deletions
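As orientation before the diff: a minimal usage sketch of the completed feature, not part of the commit itself. `lowres_noise_cond`, `use_noise_for_lowres_cond`, `use_blur_for_lowres_cond` and `lowres_noise_sample_level` are the parameters introduced below; every dimension, image size, and the unconditional setup are illustrative assumptions, not verified against this exact version.

# minimal sketch (not part of the commit): wiring up noise-level conditioning
# for an upsampler unet; all dims / sizes are illustrative assumptions

import torch
from dalle2_pytorch import Unet, Decoder

unet1 = Unet(dim = 32, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4))

unet2 = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4),
    lowres_cond = True,
    lowres_noise_cond = True            # new in this commit
)

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (64, 128),
    timesteps = 1000,
    unconditional = True,
    use_noise_for_lowres_cond = True,   # noise the low res conditioning image, as in Imagen
    use_blur_for_lowres_cond = False,   # optionally turn off the cascading ddpm blur
    lowres_noise_sample_level = 0.2     # fixed noise level used at sample time
)

images = torch.randn(2, 3, 128, 128)
loss = decoder(images, unet_number = 2)  # lowres image is noised, noise level is passed to unet2
loss.backward()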

@@ -52,10 +52,10 @@ def first(arr, d = None):
 def maybe(fn):
     @wraps(fn)
-    def inner(x):
+    def inner(x, *args, **kwargs):
         if not exists(x):
             return x
-        return fn(x)
+        return fn(x, *args, **kwargs)
     return inner

 def default(val, d):
@@ -63,13 +63,13 @@ def default(val, d):
         return val
     return d() if callable(d) else d

-def cast_tuple(val, length = None):
+def cast_tuple(val, length = None, validate = True):
     if isinstance(val, list):
         val = tuple(val)

     out = val if isinstance(val, tuple) else ((val,) * default(length, 1))

-    if exists(length):
+    if exists(length) and validate:
         assert len(out) == length

     return out
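The new `validate` flag is what later lets `Decoder.__init__` call `cast_tuple(..., num_unets - 1, validate = False)` and accept a tuple shorter than `length`. A standalone copy of these helpers (for illustration only) makes the behavior concrete:

# standalone copies of the helpers above, for illustration

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = None, validate = True):
    if isinstance(val, list):
        val = tuple(val)

    out = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length) and validate:
        assert len(out) == length
    return out

assert cast_tuple(True, 3) == (True, True, True)             # scalars broadcast to length
assert cast_tuple((True,), 3, validate = False) == (True,)   # short tuples now pass through un-asserted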
@@ -494,6 +494,9 @@ class NoiseScheduler(nn.Module):
         self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
         register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)

+    def sample_random_times(self, batch):
+        return torch.randint(0, self.num_timesteps, (batch,), device = self.betas.device, dtype = torch.long)
+
     def q_posterior(self, x_start, x_t, t):
         posterior_mean = (
             extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
@@ -1243,7 +1246,7 @@ class DiffusionPrior(nn.Module):
         # timestep conditioning from ddpm

         batch, device = image_embed.shape[0], image_embed.device
-        times = torch.randint(0, self.noise_scheduler.num_timesteps, (batch,), device = device, dtype = torch.long)
+        times = self.noise_scheduler.sample_random_times(batch)

         # scale image embed (Katherine)
@@ -1540,6 +1543,7 @@ class Unet(nn.Module):
         attn_dim_head = 32,
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
+        lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
         sparse_attn = False,
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
         cond_on_text_encodings = False,
@@ -1629,6 +1633,17 @@ class Unet(nn.Module):
             self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
             self.text_embed_dim = text_embed_dim

+        # low resolution noise conditioning, based on Imagen's upsampler training technique
+
+        self.lowres_noise_cond = lowres_noise_cond
+
+        self.to_lowres_noise_cond = nn.Sequential(
+            SinusoidalPosEmb(dim),
+            nn.Linear(dim, time_cond_dim),
+            nn.GELU(),
+            nn.Linear(time_cond_dim, time_cond_dim)
+        ) if lowres_noise_cond else None
+
         # finer control over whether to condition on image embeddings and text encodings
         # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
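This new `to_lowres_noise_cond` head mirrors the existing time conditioning pathway: the `(batch,)` noise levels are sinusoidally embedded at width `dim`, projected to `time_cond_dim`, and later added to the time conditioning `t` (see the forward hunk further down). A standalone shape sketch follows; `SinusoidalPosEmb` is re-implemented here in the standard DDPM form as an assumption, since its definition lies outside this diff.

import math
import torch
from torch import nn

# assumed implementation of SinusoidalPosEmb, matching the usual DDPM formulation
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        emb = x[:, None] * emb[None, :]
        return torch.cat((emb.sin(), emb.cos()), dim = -1)

dim, time_cond_dim = 128, 512

to_lowres_noise_cond = nn.Sequential(
    SinusoidalPosEmb(dim),
    nn.Linear(dim, time_cond_dim),
    nn.GELU(),
    nn.Linear(time_cond_dim, time_cond_dim)
)

noise_levels = torch.randint(0, 1000, (4,)).float()  # (batch,) timesteps, cast to float like .type_as(x)
cond = to_lowres_noise_cond(noise_levels)            # (4, time_cond_dim), added to the time conditioning t
print(cond.shape)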
@@ -1745,15 +1760,17 @@ class Unet(nn.Module):
         self,
         *,
         lowres_cond,
+        lowres_noise_cond,
         channels,
         channels_out,
         cond_on_image_embeds,
-        cond_on_text_encodings
+        cond_on_text_encodings,
     ):
         if lowres_cond == self.lowres_cond and \
             channels == self.channels and \
             cond_on_image_embeds == self.cond_on_image_embeds and \
             cond_on_text_encodings == self.cond_on_text_encodings and \
+            lowres_noise_cond == self.lowres_noise_cond and \
             channels_out == self.channels_out:
             return self
@@ -1762,7 +1779,8 @@ class Unet(nn.Module):
             channels = channels,
             channels_out = channels_out,
             cond_on_image_embeds = cond_on_image_embeds,
-            cond_on_text_encodings = cond_on_text_encodings
+            cond_on_text_encodings = cond_on_text_encodings,
+            lowres_noise_cond = lowres_noise_cond
         )

         return self.__class__(**{**self._locals, **updated_kwargs})
@@ -1788,6 +1806,7 @@ class Unet(nn.Module):
         *,
         image_embed,
         lowres_cond_img = None,
+        lowres_noise_level = None,
         text_encodings = None,
         image_cond_drop_prob = 0.,
         text_cond_drop_prob = 0.,
@@ -1816,6 +1835,13 @@ class Unet(nn.Module):
         time_tokens = self.to_time_tokens(time_hiddens)
         t = self.to_time_cond(time_hiddens)

+        # low res noise conditioning (similar to time above)
+
+        if exists(lowres_noise_level):
+            assert exists(self.to_lowres_noise_cond), 'lowres_noise_cond must be set to True on instantiation of the unet in order to condition on lowres noise'
+            lowres_noise_level = lowres_noise_level.type_as(x)
+            t = t + self.to_lowres_noise_cond(lowres_noise_level)
+
         # conditional dropout

         image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
@@ -1965,25 +1991,48 @@ class LowresConditioner(nn.Module):
     def __init__(
         self,
         downsample_first = True,
+        use_blur = True,
         blur_prob = 0.5,
         blur_sigma = 0.6,
         blur_kernel_size = 3,
-        input_image_range = None
+        use_noise = False,
+        input_image_range = None,
+        normalize_img_fn = identity,
+        unnormalize_img_fn = identity
     ):
         super().__init__()
         self.downsample_first = downsample_first
         self.input_image_range = input_image_range

+        self.use_blur = use_blur
         self.blur_prob = blur_prob
         self.blur_sigma = blur_sigma
         self.blur_kernel_size = blur_kernel_size

+        self.use_noise = use_noise
+        self.normalize_img = normalize_img_fn
+        self.unnormalize_img = unnormalize_img_fn
+        self.noise_scheduler = NoiseScheduler(beta_schedule = 'linear', timesteps = 1000, loss_type = 'l2') if use_noise else None
+
+    def noise_image(self, cond_fmap, noise_levels = None):
+        assert exists(self.noise_scheduler)
+
+        batch = cond_fmap.shape[0]
+        cond_fmap = self.normalize_img(cond_fmap)
+
+        random_noise_levels = default(noise_levels, lambda: self.noise_scheduler.sample_random_times(batch))
+        cond_fmap = self.noise_scheduler.q_sample(cond_fmap, t = random_noise_levels, noise = torch.randn_like(cond_fmap))
+
+        cond_fmap = self.unnormalize_img(cond_fmap)
+        return cond_fmap, random_noise_levels
+
     def forward(
         self,
         cond_fmap,
         *,
         target_image_size,
         downsample_image_size = None,
+        should_blur = True,
         blur_sigma = None,
         blur_kernel_size = None
     ):
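`noise_image` above is plain forward diffusion applied to the conditioning image: normalize to `[-1, 1]`, run `q_sample` at random timesteps of a dedicated 1000-step linear schedule, unnormalize, and hand back both the noised image and the timesteps so the unet knows how corrupted its conditioning is. A self-contained sketch of the `q_sample` step; the `1e-4` to `0.02` beta range is the usual linear-schedule assumption, not something shown in this diff.

import torch

# standalone sketch of what noise_image does, not the class itself:
# x_t = sqrt(alphas_cumprod_t) * x_0 + sqrt(1 - alphas_cumprod_t) * noise
betas = torch.linspace(1e-4, 0.02, 1000)          # assumed linear beta schedule
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)

def q_sample(x_start, t, noise):
    ac = alphas_cumprod[t].view(-1, 1, 1, 1)      # gather per-sample cumulative alphas
    return ac.sqrt() * x_start + (1. - ac).sqrt() * noise

cond_fmap = torch.rand(2, 3, 64, 64) * 2 - 1      # conditioning image, normalized to [-1, 1]
t = torch.randint(0, 1000, (2,))                  # random noise levels, as in sample_random_times
noised = q_sample(cond_fmap, t, torch.randn_like(cond_fmap))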
@@ -1993,7 +2042,7 @@ class LowresConditioner(nn.Module):
         # blur is only applied 50% of the time
         # section 3.1 in https://arxiv.org/abs/2106.15282

-        if random.random() < self.blur_prob:
+        if self.use_blur and should_blur and random.random() < self.blur_prob:

             # when training, blur the low resolution conditional image
@@ -2015,8 +2064,21 @@ class LowresConditioner(nn.Module):
             cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))

+        # resize to target image size
+
         cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range, nearest = True)

-        return cond_fmap
+        # noise conditioning, as done in Imagen
+        # as a replacement for the BSR noising, and potentially replacing the blurring for the first stage too
+
+        random_noise_levels = None
+
+        if self.use_noise:
+            cond_fmap, random_noise_levels = self.noise_image(cond_fmap)
+
+        # return the conditioning feature map, as well as the augmentation noise levels
+
+        return cond_fmap, random_noise_levels

 class Decoder(nn.Module):
     def __init__(
@@ -2037,10 +2099,13 @@ class Decoder(nn.Module):
         predict_x_start_for_latent_diffusion = False,
         image_sizes = None, # for cascading ddpm, image size at each stage
         random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
+        use_noise_for_lowres_cond = False, # whether to use Imagen-like noising for low resolution conditioning
+        use_blur_for_lowres_cond = True, # whether to use the blur conditioning used in the original cascading ddpm paper, as well as DALL-E2
         lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
         blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
         blur_sigma = 0.6, # cascading ddpm - blur sigma
         blur_kernel_size = 3, # cascading ddpm - blur kernel size
+        lowres_noise_sample_level = 0.2, # in imagen paper, they use a 0.2 noise level at sample time for low resolution conditioning
         clip_denoised = True,
         clip_x_start = True,
         clip_adapter_overrides = dict(),
@@ -2088,10 +2153,17 @@ class Decoder(nn.Module):
         self.channels = channels

+        # normalize and unnormalize image functions
+
+        self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
+        self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
+
         # verify conditioning method

         unets = cast_tuple(unet)
         num_unets = len(unets)
+        self.num_unets = num_unets

         self.unconditional = unconditional
@@ -2107,12 +2179,28 @@ class Decoder(nn.Module):
         self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
         self.vb_loss_weight = vb_loss_weight

+        # default and validate conditioning parameters
+
+        use_noise_for_lowres_cond = cast_tuple(use_noise_for_lowres_cond, num_unets - 1, validate = False)
+        use_blur_for_lowres_cond = cast_tuple(use_blur_for_lowres_cond, num_unets - 1, validate = False)
+
+        if len(use_noise_for_lowres_cond) < num_unets:
+            use_noise_for_lowres_cond = (False, *use_noise_for_lowres_cond)
+
+        if len(use_blur_for_lowres_cond) < num_unets:
+            use_blur_for_lowres_cond = (False, *use_blur_for_lowres_cond)
+
+        assert not use_noise_for_lowres_cond[0], 'first unet will never need low res noise conditioning'
+        assert not use_blur_for_lowres_cond[0], 'first unet will never need low res blur conditioning'
+
+        assert num_unets == 1 or all((use_noise or use_blur) for use_noise, use_blur in zip(use_noise_for_lowres_cond[1:], use_blur_for_lowres_cond[1:]))
+
         # construct unets and vaes

         self.unets = nn.ModuleList([])
         self.vaes = nn.ModuleList([])

-        for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
+        for ind, (one_unet, one_vae, one_unet_learned_var, lowres_noise_cond) in enumerate(zip(unets, vaes, learned_variance, use_noise_for_lowres_cond)):
             assert isinstance(one_unet, Unet)
             assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
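Worked through for a hypothetical three-unet cascade with `use_noise_for_lowres_cond = True`: `cast_tuple(True, 2, validate = False)` gives `(True, True)`, prepending `False` for the first unet gives `(False, True, True)`, and both leading-`False` asserts pass. The final assert then enforces that every upsampler unet conditions on at least one of noise or blur.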
@@ -2124,6 +2212,7 @@ class Decoder(nn.Module):
             one_unet = one_unet.cast_model_parameters(
                 lowres_cond = not is_first,
+                lowres_noise_cond = lowres_noise_cond,
                 cond_on_image_embeds = not unconditional and is_first,
                 cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings,
                 channels = unet_channels,
@@ -2166,7 +2255,7 @@ class Decoder(nn.Module):
         image_sizes = default(image_sizes, (image_size,))
         image_sizes = tuple(sorted(set(image_sizes)))

-        assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
+        assert self.num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({self.num_unets}) for resolutions {image_sizes}'

         self.image_sizes = image_sizes
         self.sample_channels = cast_tuple(self.channels, len(image_sizes))
@@ -2186,16 +2275,31 @@ class Decoder(nn.Module):
         # cascading ddpm related stuff

         lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
-        assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
+        assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'

-        self.to_lowres_cond = LowresConditioner(
-            downsample_first = lowres_downsample_first,
-            blur_prob = blur_prob,
-            blur_sigma = blur_sigma,
-            blur_kernel_size = blur_kernel_size,
-            input_image_range = self.input_image_range
-        )
+        self.lowres_conds = nn.ModuleList([])
+
+        for unet_index, use_noise, use_blur in zip(range(num_unets), use_noise_for_lowres_cond, use_blur_for_lowres_cond):
+            if unet_index == 0:
+                self.lowres_conds.append(None)
+                continue
+
+            lowres_cond = LowresConditioner(
+                downsample_first = lowres_downsample_first,
+                use_blur = use_blur,
+                use_noise = use_noise,
+                blur_prob = blur_prob,
+                blur_sigma = blur_sigma,
+                blur_kernel_size = blur_kernel_size,
+                input_image_range = self.input_image_range,
+                normalize_img_fn = self.normalize_img,
+                unnormalize_img_fn = self.unnormalize_img
+            )
+
+            self.lowres_conds.append(lowres_cond)
+
+        self.lowres_noise_sample_level = lowres_noise_sample_level

         # classifier free guidance

         self.image_cond_drop_prob = image_cond_drop_prob
@@ -2212,11 +2316,6 @@ class Decoder(nn.Module):
         self.use_dynamic_thres = use_dynamic_thres
         self.dynamic_thres_percentile = dynamic_thres_percentile

-        # normalize and unnormalize image functions
-
-        self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
-        self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
-
         # device tracker

         self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)
@@ -2230,7 +2329,7 @@ class Decoder(nn.Module):
         return any([unet.cond_on_text_encodings for unet in self.unets])

     def get_unet(self, unet_number):
-        assert 0 < unet_number <= len(self.unets)
+        assert 0 < unet_number <= self.num_unets
         index = unet_number - 1
         return self.unets[index]
@@ -2316,7 +2415,7 @@ class Decoder(nn.Module):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
+    def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
         device = self.device
         b = shape[0]
@@ -2334,6 +2433,7 @@ class Decoder(nn.Module):
                 text_encodings = text_encodings,
                 cond_scale = cond_scale,
                 lowres_cond_img = lowres_cond_img,
+                lowres_noise_level = lowres_noise_level,
                 predict_x_start = predict_x_start,
                 noise_scheduler = noise_scheduler,
                 learned_variance = learned_variance,
@@ -2344,7 +2444,7 @@ class Decoder(nn.Module):
         return unnormalize_img

     @torch.no_grad()
-    def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
+    def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
         batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta

         times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
@@ -2363,7 +2463,7 @@ class Decoder(nn.Module):
             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)

-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)

             if learned_variance:
                 pred, _ = pred.chunk(2, dim = 1)
@@ -2402,7 +2502,7 @@ class Decoder(nn.Module):
         return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)

-    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
+    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
         noise = default(noise, lambda: torch.randn_like(x_start))

         # normalize to [-1, 1]
@@ -2421,6 +2521,7 @@ class Decoder(nn.Module):
             image_embed = image_embed,
             text_encodings = text_encodings,
             lowres_cond_img = lowres_cond_img,
+            lowres_noise_level = lowres_noise_level,
            image_cond_drop_prob = self.image_cond_drop_prob,
            text_cond_drop_prob = self.text_cond_drop_prob,
        )
@@ -2500,20 +2601,24 @@ class Decoder(nn.Module):
         img = None
         is_cuda = next(self.parameters()).is_cuda

-        num_unets = len(self.unets)
+        num_unets = self.num_unets
         cond_scale = cast_tuple(cond_scale, num_unets)

-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps, cond_scale)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
             context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()

             with context:
-                lowres_cond_img = None
+                lowres_cond_img = lowres_noise_level = None
                 shape = (batch_size, channel, image_size, image_size)

                 if unet.lowres_cond:
                     lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = True)

+                    if lowres_cond.use_noise:
+                        lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)
+                        lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)
+
                 is_latent_diffusion = isinstance(vae, VQGanVAE)
                 image_size = vae.get_encoded_fmap_size(image_size)
                 shape = (batch_size, vae.encoded_dim, image_size, image_size)
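At sample time there is no random draw: the fixed `lowres_noise_sample_level = 0.2` is mapped onto the augmentation scheduler's 1000 timesteps, giving `int(0.2 * 1000) = 200` for every image in the batch. A tiny standalone check:

import torch

# standalone sketch of the sample-time noise level computed above
lowres_noise_sample_level = 0.2
batch_size = 4

lowres_noise_level = torch.full((batch_size,), int(lowres_noise_sample_level * 1000), dtype = torch.long)
print(lowres_noise_level)  # tensor([200, 200, 200, 200])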
@@ -2530,6 +2635,7 @@ class Decoder(nn.Module):
                     learned_variance = learned_variance,
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,
+                    lowres_noise_level = lowres_noise_level,
                     is_latent_diffusion = is_latent_diffusion,
                     noise_scheduler = noise_scheduler,
                     timesteps = sample_timesteps
@@ -2551,7 +2657,7 @@ class Decoder(nn.Module):
         unet_number = None,
         return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
     ):
-        assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
+        assert not (self.num_unets > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {self.num_unets}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
         unet_index = unet_number - 1
@@ -2559,6 +2665,7 @@ class Decoder(nn.Module):
         vae = self.vaes[unet_index]
         noise_scheduler = self.noise_schedulers[unet_index]
+        lowres_conditioner = self.lowres_conds[unet_index]
         target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
@@ -2581,7 +2688,7 @@ class Decoder(nn.Module):
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'

-        lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
+        lowres_cond_img, lowres_noise_level = lowres_conditioner(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if exists(lowres_conditioner) else (None, None)

         image = resize_image_to(image, target_image_size, nearest = True)

         if exists(random_crop_size):
@@ -2599,7 +2706,7 @@ class Decoder(nn.Module):
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)

-        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
+        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)

         if not return_lowres_cond_image:
             return losses

@@ -1 +1 @@
-__version__ = '0.24.3'
+__version__ = '0.25.0'