|
|
|
|
@@ -52,10 +52,10 @@ def first(arr, d = None):
|
|
|
|
|
|
|
|
|
|
def maybe(fn):
|
|
|
|
|
@wraps(fn)
|
|
|
|
|
def inner(x):
|
|
|
|
|
def inner(x, *args, **kwargs):
|
|
|
|
|
if not exists(x):
|
|
|
|
|
return x
|
|
|
|
|
return fn(x)
|
|
|
|
|
return fn(x, *args, **kwargs)
|
|
|
|
|
return inner
|
|
|
|
|
|
|
|
|
|
def default(val, d):
|
|
|
|
|
@@ -63,13 +63,13 @@ def default(val, d):
|
|
|
|
|
return val
|
|
|
|
|
return d() if callable(d) else d
|
|
|
|
|
|
|
|
|
|
def cast_tuple(val, length = None):
|
|
|
|
|
def cast_tuple(val, length = None, validate = True):
|
|
|
|
|
if isinstance(val, list):
|
|
|
|
|
val = tuple(val)
|
|
|
|
|
|
|
|
|
|
out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
|
|
|
|
|
|
|
|
|
|
if exists(length):
|
|
|
|
|
if exists(length) and validate:
|
|
|
|
|
assert len(out) == length
|
|
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
@@ -77,6 +77,11 @@ def cast_tuple(val, length = None):
|
|
|
|
|
def module_device(module):
|
|
|
|
|
return next(module.parameters()).device
|
|
|
|
|
|
|
|
|
|
def zero_init_(m):
|
|
|
|
|
nn.init.zeros_(m.weight)
|
|
|
|
|
if exists(m.bias):
|
|
|
|
|
nn.init.zeros_(m.bias)
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
def null_context(*args, **kwargs):
|
|
|
|
|
yield
|
|
|
|
|
@@ -141,7 +146,7 @@ def resize_image_to(
|
|
|
|
|
scale_factors = target_image_size / orig_image_size
|
|
|
|
|
out = resize(image, scale_factors = scale_factors, **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
out = F.interpolate(image, target_image_size, mode = 'nearest', align_corners = False)
|
|
|
|
|
out = F.interpolate(image, target_image_size, mode = 'nearest')
|
|
|
|
|
|
|
|
|
|
if exists(clamp_range):
|
|
|
|
|
out = out.clamp(*clamp_range)
|
|
|
|
|
@@ -160,7 +165,7 @@ def unnormalize_zero_to_one(normed_img):
|
|
|
|
|
|
|
|
|
|
# clip related adapters
|
|
|
|
|
|
|
|
|
|
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
|
|
|
|
|
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings'])
|
|
|
|
|
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
|
|
|
|
|
|
|
|
|
|
class BaseClipAdapter(nn.Module):
|
|
|
|
|
@@ -221,7 +226,7 @@ class XClipAdapter(BaseClipAdapter):
|
|
|
|
|
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
|
|
|
|
text_embed = self.clip.to_text_latent(text_cls)
|
|
|
|
|
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
|
|
|
|
|
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
|
|
|
|
|
return EmbeddedText(l2norm(text_embed), text_encodings)
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def embed_image(self, image):
|
|
|
|
|
@@ -257,7 +262,7 @@ class CoCaAdapter(BaseClipAdapter):
|
|
|
|
|
text_mask = text != 0
|
|
|
|
|
text_embed, text_encodings = self.clip.embed_text(text)
|
|
|
|
|
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
|
|
|
|
|
return EmbeddedText(text_embed, text_encodings, text_mask)
|
|
|
|
|
return EmbeddedText(text_embed, text_encodings)
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def embed_image(self, image):
|
|
|
|
|
@@ -273,6 +278,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|
|
|
|
import clip
|
|
|
|
|
openai_clip, preprocess = clip.load(name)
|
|
|
|
|
super().__init__(openai_clip)
|
|
|
|
|
self.eos_id = 49407 # for handling 0 being also '!'
|
|
|
|
|
|
|
|
|
|
text_attention_final = self.find_layer('ln_final')
|
|
|
|
|
self.handle = text_attention_final.register_forward_hook(self._hook)
|
|
|
|
|
@@ -311,14 +317,17 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def embed_text(self, text):
|
|
|
|
|
text = text[..., :self.max_text_len]
|
|
|
|
|
text_mask = text != 0
|
|
|
|
|
|
|
|
|
|
is_eos_id = (text == self.eos_id)
|
|
|
|
|
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
|
|
|
|
|
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
|
|
|
|
|
assert not self.cleared
|
|
|
|
|
|
|
|
|
|
text_embed = self.clip.encode_text(text)
|
|
|
|
|
text_encodings = self.text_encodings
|
|
|
|
|
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
|
|
|
|
|
del self.text_encodings
|
|
|
|
|
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
|
|
|
|
|
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def embed_image(self, image):
|
|
|
|
|
@@ -485,6 +494,9 @@ class NoiseScheduler(nn.Module):
|
|
|
|
|
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
|
|
|
|
|
register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
|
|
|
|
|
|
|
|
|
|
def sample_random_times(self, batch):
|
|
|
|
|
return torch.randint(0, self.num_timesteps, (batch,), device = self.betas.device, dtype = torch.long)
|
|
|
|
|
|
|
|
|
|
def q_posterior(self, x_start, x_t, t):
|
|
|
|
|
posterior_mean = (
|
|
|
|
|
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
|
|
|
|
@@ -510,7 +522,7 @@ class NoiseScheduler(nn.Module):
|
|
|
|
|
|
|
|
|
|
def predict_noise_from_start(self, x_t, t, x0):
|
|
|
|
|
return (
|
|
|
|
|
(x0 - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t) / \
|
|
|
|
|
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
|
|
|
|
|
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -522,25 +534,31 @@ class NoiseScheduler(nn.Module):
|
|
|
|
|
# diffusion prior
|
|
|
|
|
|
|
|
|
|
class LayerNorm(nn.Module):
|
|
|
|
|
def __init__(self, dim, eps = 1e-5):
|
|
|
|
|
def __init__(self, dim, eps = 1e-5, stable = False):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.eps = eps
|
|
|
|
|
self.stable = stable
|
|
|
|
|
self.g = nn.Parameter(torch.ones(dim))
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = x / x.amax(dim = -1, keepdim = True).detach()
|
|
|
|
|
if self.stable:
|
|
|
|
|
x = x / x.amax(dim = -1, keepdim = True).detach()
|
|
|
|
|
|
|
|
|
|
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
|
|
|
|
|
mean = torch.mean(x, dim = -1, keepdim = True)
|
|
|
|
|
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
|
|
|
|
|
|
|
|
|
class ChanLayerNorm(nn.Module):
|
|
|
|
|
def __init__(self, dim, eps = 1e-5):
|
|
|
|
|
def __init__(self, dim, eps = 1e-5, stable = False):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.eps = eps
|
|
|
|
|
self.stable = stable
|
|
|
|
|
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = x / x.amax(dim = 1, keepdim = True).detach()
|
|
|
|
|
if self.stable:
|
|
|
|
|
x = x / x.amax(dim = 1, keepdim = True).detach()
|
|
|
|
|
|
|
|
|
|
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
|
|
|
|
mean = torch.mean(x, dim = 1, keepdim = True)
|
|
|
|
|
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
|
|
|
|
@@ -664,7 +682,7 @@ class Attention(nn.Module):
|
|
|
|
|
dropout = 0.,
|
|
|
|
|
causal = False,
|
|
|
|
|
rotary_emb = None,
|
|
|
|
|
pb_relax_alpha = 32 ** 2
|
|
|
|
|
pb_relax_alpha = 128
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.pb_relax_alpha = pb_relax_alpha
|
|
|
|
|
@@ -755,6 +773,7 @@ class CausalTransformer(nn.Module):
|
|
|
|
|
dim_head = 64,
|
|
|
|
|
heads = 8,
|
|
|
|
|
ff_mult = 4,
|
|
|
|
|
norm_in = False,
|
|
|
|
|
norm_out = True,
|
|
|
|
|
attn_dropout = 0.,
|
|
|
|
|
ff_dropout = 0.,
|
|
|
|
|
@@ -763,6 +782,8 @@ class CausalTransformer(nn.Module):
|
|
|
|
|
rotary_emb = True
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
|
|
|
|
|
|
|
|
|
|
self.rel_pos_bias = RelPosBias(heads = heads)
|
|
|
|
|
|
|
|
|
|
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
|
|
|
|
|
@@ -774,20 +795,18 @@ class CausalTransformer(nn.Module):
|
|
|
|
|
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
|
|
|
|
]))
|
|
|
|
|
|
|
|
|
|
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
|
|
|
|
self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
|
|
|
|
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|
x,
|
|
|
|
|
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
|
|
|
|
|
):
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
n, device = x.shape[1], x.device
|
|
|
|
|
|
|
|
|
|
x = self.init_norm(x)
|
|
|
|
|
|
|
|
|
|
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
|
|
|
|
|
|
|
|
|
for attn, ff in self.layers:
|
|
|
|
|
x = attn(x, mask = mask, attn_bias = attn_bias) + x
|
|
|
|
|
x = attn(x, attn_bias = attn_bias) + x
|
|
|
|
|
x = ff(x) + x
|
|
|
|
|
|
|
|
|
|
out = self.norm(x)
|
|
|
|
|
@@ -801,6 +820,7 @@ class DiffusionPriorNetwork(nn.Module):
|
|
|
|
|
num_time_embeds = 1,
|
|
|
|
|
num_image_embeds = 1,
|
|
|
|
|
num_text_embeds = 1,
|
|
|
|
|
max_text_len = 256,
|
|
|
|
|
**kwargs
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
@@ -826,6 +846,11 @@ class DiffusionPriorNetwork(nn.Module):
|
|
|
|
|
self.learned_query = nn.Parameter(torch.randn(dim))
|
|
|
|
|
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
|
|
|
|
|
|
|
|
|
# dalle1 learned padding strategy
|
|
|
|
|
|
|
|
|
|
self.max_text_len = max_text_len
|
|
|
|
|
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
|
|
|
|
|
|
|
|
|
|
def forward_with_cond_scale(
|
|
|
|
|
self,
|
|
|
|
|
*args,
|
|
|
|
|
@@ -847,7 +872,6 @@ class DiffusionPriorNetwork(nn.Module):
|
|
|
|
|
*,
|
|
|
|
|
text_embed,
|
|
|
|
|
text_encodings = None,
|
|
|
|
|
mask = None,
|
|
|
|
|
cond_drop_prob = 0.
|
|
|
|
|
):
|
|
|
|
|
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
|
|
|
|
@@ -865,9 +889,28 @@ class DiffusionPriorNetwork(nn.Module):
|
|
|
|
|
|
|
|
|
|
if not exists(text_encodings):
|
|
|
|
|
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
|
|
|
|
|
|
|
|
|
|
mask = torch.any(text_encodings != 0., dim = -1)
|
|
|
|
|
|
|
|
|
|
if not exists(mask):
|
|
|
|
|
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
|
|
|
|
|
# replace any padding in the text encodings with learned padding tokens unique across position
|
|
|
|
|
|
|
|
|
|
text_encodings = text_encodings[:, :self.max_text_len]
|
|
|
|
|
mask = mask[:, :self.max_text_len]
|
|
|
|
|
|
|
|
|
|
text_len = text_encodings.shape[-2]
|
|
|
|
|
remainder = self.max_text_len - text_len
|
|
|
|
|
|
|
|
|
|
if remainder > 0:
|
|
|
|
|
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
|
|
|
|
|
mask = F.pad(mask, (0, remainder), value = False)
|
|
|
|
|
|
|
|
|
|
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
|
|
|
|
|
|
|
|
|
|
text_encodings = torch.where(
|
|
|
|
|
rearrange(mask, 'b n -> b n 1').clone(),
|
|
|
|
|
text_encodings,
|
|
|
|
|
null_text_embeds
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# classifier free guidance
|
|
|
|
|
|
|
|
|
|
@@ -884,9 +927,8 @@ class DiffusionPriorNetwork(nn.Module):
|
|
|
|
|
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
|
|
|
|
# but let's just do it right
|
|
|
|
|
|
|
|
|
|
if exists(mask):
|
|
|
|
|
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
|
|
|
|
|
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
|
|
|
|
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
|
|
|
|
|
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
|
|
|
|
|
|
|
|
|
time_embed = self.to_time_embeds(diffusion_timesteps)
|
|
|
|
|
|
|
|
|
|
@@ -902,7 +944,7 @@ class DiffusionPriorNetwork(nn.Module):
|
|
|
|
|
|
|
|
|
|
# attend
|
|
|
|
|
|
|
|
|
|
tokens = self.causal_transformer(tokens, mask = mask)
|
|
|
|
|
tokens = self.causal_transformer(tokens)
|
|
|
|
|
|
|
|
|
|
# get learned query, which should predict the image embedding (per DDPM timestep)
|
|
|
|
|
|
|
|
|
|
@@ -1147,12 +1189,12 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
batch_size = text.shape[0]
|
|
|
|
|
image_embed_dim = self.image_embed_dim
|
|
|
|
|
|
|
|
|
|
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
|
|
|
|
|
text_embed, text_encodings = self.clip.embed_text(text)
|
|
|
|
|
|
|
|
|
|
text_cond = dict(text_embed = text_embed)
|
|
|
|
|
|
|
|
|
|
if self.condition_on_text_encodings:
|
|
|
|
|
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
|
|
|
|
text_cond = {**text_cond, 'text_encodings': text_encodings}
|
|
|
|
|
|
|
|
|
|
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
|
|
|
|
|
|
|
|
|
|
@@ -1180,7 +1222,6 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
|
|
|
|
|
image_embed = None,
|
|
|
|
|
text_encodings = None, # as well as CLIP text encodings
|
|
|
|
|
text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
|
|
|
|
|
*args,
|
|
|
|
|
**kwargs
|
|
|
|
|
):
|
|
|
|
|
@@ -1194,19 +1235,18 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
# calculate text conditionings, based on what is passed in
|
|
|
|
|
|
|
|
|
|
if exists(text):
|
|
|
|
|
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
|
|
|
|
|
text_embed, text_encodings = self.clip.embed_text(text)
|
|
|
|
|
|
|
|
|
|
text_cond = dict(text_embed = text_embed)
|
|
|
|
|
|
|
|
|
|
if self.condition_on_text_encodings:
|
|
|
|
|
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
|
|
|
|
|
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
|
|
|
|
|
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
|
|
|
|
text_cond = {**text_cond, 'text_encodings': text_encodings}
|
|
|
|
|
|
|
|
|
|
# timestep conditioning from ddpm
|
|
|
|
|
|
|
|
|
|
batch, device = image_embed.shape[0], image_embed.device
|
|
|
|
|
times = torch.randint(0, self.noise_scheduler.num_timesteps, (batch,), device = device, dtype = torch.long)
|
|
|
|
|
times = self.noise_scheduler.sample_random_times(batch)
|
|
|
|
|
|
|
|
|
|
# scale image embed (Katherine)
|
|
|
|
|
|
|
|
|
|
@@ -1218,17 +1258,44 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
|
|
|
|
|
# decoder
|
|
|
|
|
|
|
|
|
|
def ConvTransposeUpsample(dim, dim_out = None):
|
|
|
|
|
dim_out = default(dim_out, dim)
|
|
|
|
|
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
|
|
|
|
|
|
|
|
|
|
def NearestUpsample(dim, dim_out = None):
|
|
|
|
|
dim_out = default(dim_out, dim)
|
|
|
|
|
|
|
|
|
|
return nn.Sequential(
|
|
|
|
|
nn.Upsample(scale_factor = 2, mode = 'nearest'),
|
|
|
|
|
nn.Conv2d(dim, dim_out, 3, padding = 1)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
class PixelShuffleUpsample(nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
|
|
|
|
|
https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, dim, dim_out = None):
|
|
|
|
|
super().__init__()
|
|
|
|
|
dim_out = default(dim_out, dim)
|
|
|
|
|
conv = nn.Conv2d(dim, dim_out * 4, 1)
|
|
|
|
|
|
|
|
|
|
self.net = nn.Sequential(
|
|
|
|
|
conv,
|
|
|
|
|
nn.SiLU(),
|
|
|
|
|
nn.PixelShuffle(2)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.init_conv_(conv)
|
|
|
|
|
|
|
|
|
|
def init_conv_(self, conv):
|
|
|
|
|
o, i, h, w = conv.weight.shape
|
|
|
|
|
conv_weight = torch.empty(o // 4, i, h, w)
|
|
|
|
|
nn.init.kaiming_uniform_(conv_weight)
|
|
|
|
|
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
|
|
|
|
|
|
|
|
|
|
conv.weight.data.copy_(conv_weight)
|
|
|
|
|
nn.init.zeros_(conv.bias.data)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
return self.net(x)
|
|
|
|
|
|
|
|
|
|
def Downsample(dim, *, dim_out = None):
|
|
|
|
|
dim_out = default(dim_out, dim)
|
|
|
|
|
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
|
|
|
|
@@ -1475,9 +1542,10 @@ class Unet(nn.Module):
|
|
|
|
|
self_attn = False,
|
|
|
|
|
attn_dim_head = 32,
|
|
|
|
|
attn_heads = 16,
|
|
|
|
|
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
|
|
|
|
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
|
|
|
|
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
|
|
|
|
|
sparse_attn = False,
|
|
|
|
|
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
|
|
|
|
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
|
|
|
|
cond_on_text_encodings = False,
|
|
|
|
|
max_text_len = 256,
|
|
|
|
|
cond_on_image_embeds = False,
|
|
|
|
|
@@ -1486,12 +1554,13 @@ class Unet(nn.Module):
|
|
|
|
|
init_conv_kernel_size = 7,
|
|
|
|
|
resnet_groups = 8,
|
|
|
|
|
num_resnet_blocks = 2,
|
|
|
|
|
init_cross_embed = True,
|
|
|
|
|
init_cross_embed_kernel_sizes = (3, 7, 15),
|
|
|
|
|
cross_embed_downsample = False,
|
|
|
|
|
cross_embed_downsample_kernel_sizes = (2, 4),
|
|
|
|
|
memory_efficient = False,
|
|
|
|
|
scale_skip_connection = False,
|
|
|
|
|
nearest_upsample = False,
|
|
|
|
|
pixel_shuffle_upsample = True,
|
|
|
|
|
final_conv_kernel_size = 1,
|
|
|
|
|
**kwargs
|
|
|
|
|
):
|
|
|
|
|
@@ -1514,7 +1583,7 @@ class Unet(nn.Module):
|
|
|
|
|
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
|
|
|
|
init_dim = default(init_dim, dim)
|
|
|
|
|
|
|
|
|
|
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
|
|
|
|
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
|
|
|
|
|
|
|
|
|
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
|
|
|
|
in_out = list(zip(dims[:-1], dims[1:]))
|
|
|
|
|
@@ -1564,6 +1633,17 @@ class Unet(nn.Module):
|
|
|
|
|
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
|
|
|
|
|
self.text_embed_dim = text_embed_dim
|
|
|
|
|
|
|
|
|
|
# low resolution noise conditiong, based on Imagen's upsampler training technique
|
|
|
|
|
|
|
|
|
|
self.lowres_noise_cond = lowres_noise_cond
|
|
|
|
|
|
|
|
|
|
self.to_lowres_noise_cond = nn.Sequential(
|
|
|
|
|
SinusoidalPosEmb(dim),
|
|
|
|
|
nn.Linear(dim, time_cond_dim),
|
|
|
|
|
nn.GELU(),
|
|
|
|
|
nn.Linear(time_cond_dim, time_cond_dim)
|
|
|
|
|
) if lowres_noise_cond else None
|
|
|
|
|
|
|
|
|
|
# finer control over whether to condition on image embeddings and text encodings
|
|
|
|
|
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
|
|
|
|
|
|
|
|
|
@@ -1605,7 +1685,7 @@ class Unet(nn.Module):
|
|
|
|
|
|
|
|
|
|
# upsample klass
|
|
|
|
|
|
|
|
|
|
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
|
|
|
|
|
upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
|
|
|
|
|
|
|
|
|
|
# give memory efficient unet an initial resnet block
|
|
|
|
|
|
|
|
|
|
@@ -1667,7 +1747,12 @@ class Unet(nn.Module):
|
|
|
|
|
]))
|
|
|
|
|
|
|
|
|
|
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
|
|
|
|
|
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
|
|
|
|
|
|
|
|
|
|
out_dim_in = dim + (channels if lowres_cond else 0)
|
|
|
|
|
|
|
|
|
|
self.to_out = nn.Conv2d(out_dim_in, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
|
|
|
|
|
|
|
|
|
|
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
|
|
|
|
|
|
|
|
|
|
# if the current settings for the unet are not correct
|
|
|
|
|
# for cascading DDPM, then reinit the unet with the right settings
|
|
|
|
|
@@ -1675,15 +1760,17 @@ class Unet(nn.Module):
|
|
|
|
|
self,
|
|
|
|
|
*,
|
|
|
|
|
lowres_cond,
|
|
|
|
|
lowres_noise_cond,
|
|
|
|
|
channels,
|
|
|
|
|
channels_out,
|
|
|
|
|
cond_on_image_embeds,
|
|
|
|
|
cond_on_text_encodings
|
|
|
|
|
cond_on_text_encodings,
|
|
|
|
|
):
|
|
|
|
|
if lowres_cond == self.lowres_cond and \
|
|
|
|
|
channels == self.channels and \
|
|
|
|
|
cond_on_image_embeds == self.cond_on_image_embeds and \
|
|
|
|
|
cond_on_text_encodings == self.cond_on_text_encodings and \
|
|
|
|
|
cond_on_lowres_noise == self.cond_on_lowres_noise and \
|
|
|
|
|
channels_out == self.channels_out:
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
@@ -1692,7 +1779,8 @@ class Unet(nn.Module):
|
|
|
|
|
channels = channels,
|
|
|
|
|
channels_out = channels_out,
|
|
|
|
|
cond_on_image_embeds = cond_on_image_embeds,
|
|
|
|
|
cond_on_text_encodings = cond_on_text_encodings
|
|
|
|
|
cond_on_text_encodings = cond_on_text_encodings,
|
|
|
|
|
lowres_noise_cond = lowres_noise_cond
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return self.__class__(**{**self._locals, **updated_kwargs})
|
|
|
|
|
@@ -1718,8 +1806,8 @@ class Unet(nn.Module):
|
|
|
|
|
*,
|
|
|
|
|
image_embed,
|
|
|
|
|
lowres_cond_img = None,
|
|
|
|
|
lowres_noise_level = None,
|
|
|
|
|
text_encodings = None,
|
|
|
|
|
text_mask = None,
|
|
|
|
|
image_cond_drop_prob = 0.,
|
|
|
|
|
text_cond_drop_prob = 0.,
|
|
|
|
|
blur_sigma = None,
|
|
|
|
|
@@ -1747,6 +1835,13 @@ class Unet(nn.Module):
|
|
|
|
|
time_tokens = self.to_time_tokens(time_hiddens)
|
|
|
|
|
t = self.to_time_cond(time_hiddens)
|
|
|
|
|
|
|
|
|
|
# low res noise conditioning (similar to time above)
|
|
|
|
|
|
|
|
|
|
if exists(lowres_noise_level):
|
|
|
|
|
assert exists(self.to_lowres_noise_cond), 'lowres_noise_cond must be set to True on instantiation of the unet in order to conditiong on lowres noise'
|
|
|
|
|
lowres_noise_level = lowres_noise_level.type_as(x)
|
|
|
|
|
t = t + self.to_lowres_noise_cond(lowres_noise_level)
|
|
|
|
|
|
|
|
|
|
# conditional dropout
|
|
|
|
|
|
|
|
|
|
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
|
|
|
|
@@ -1791,23 +1886,27 @@ class Unet(nn.Module):
|
|
|
|
|
text_tokens = None
|
|
|
|
|
|
|
|
|
|
if exists(text_encodings) and self.cond_on_text_encodings:
|
|
|
|
|
assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'
|
|
|
|
|
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
|
|
|
|
|
|
|
|
|
|
text_mask = torch.any(text_encodings != 0., dim = -1)
|
|
|
|
|
|
|
|
|
|
text_tokens = self.text_to_cond(text_encodings)
|
|
|
|
|
|
|
|
|
|
text_tokens = text_tokens[:, :self.max_text_len]
|
|
|
|
|
text_mask = text_mask[:, :self.max_text_len]
|
|
|
|
|
|
|
|
|
|
text_tokens_len = text_tokens.shape[1]
|
|
|
|
|
remainder = self.max_text_len - text_tokens_len
|
|
|
|
|
|
|
|
|
|
if remainder > 0:
|
|
|
|
|
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
|
|
|
|
|
text_mask = F.pad(text_mask, (0, remainder), value = False)
|
|
|
|
|
|
|
|
|
|
if exists(text_mask):
|
|
|
|
|
if remainder > 0:
|
|
|
|
|
text_mask = F.pad(text_mask, (0, remainder), value = False)
|
|
|
|
|
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
|
|
|
|
|
|
|
|
|
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
|
|
|
|
text_keep_mask = text_mask & text_keep_mask
|
|
|
|
|
assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}. text encoding is of shape {text_encodings.shape}'
|
|
|
|
|
text_keep_mask = text_mask & text_keep_mask
|
|
|
|
|
|
|
|
|
|
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
|
|
|
|
|
|
|
|
|
|
@@ -1854,7 +1953,7 @@ class Unet(nn.Module):
|
|
|
|
|
hiddens.append(x)
|
|
|
|
|
|
|
|
|
|
x = attn(x)
|
|
|
|
|
hiddens.append(x)
|
|
|
|
|
hiddens.append(x.contiguous())
|
|
|
|
|
|
|
|
|
|
if exists(post_downsample):
|
|
|
|
|
x = post_downsample(x)
|
|
|
|
|
@@ -1882,49 +1981,82 @@ class Unet(nn.Module):
|
|
|
|
|
x = torch.cat((x, r), dim = 1)
|
|
|
|
|
|
|
|
|
|
x = self.final_resnet_block(x, t)
|
|
|
|
|
|
|
|
|
|
if exists(lowres_cond_img):
|
|
|
|
|
x = torch.cat((x, lowres_cond_img), dim = 1)
|
|
|
|
|
|
|
|
|
|
return self.to_out(x)
|
|
|
|
|
|
|
|
|
|
class LowresConditioner(nn.Module):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
downsample_first = True,
|
|
|
|
|
downsample_mode_nearest = False,
|
|
|
|
|
use_blur = True,
|
|
|
|
|
blur_prob = 0.5,
|
|
|
|
|
blur_sigma = 0.6,
|
|
|
|
|
blur_kernel_size = 3,
|
|
|
|
|
input_image_range = None
|
|
|
|
|
use_noise = False,
|
|
|
|
|
input_image_range = None,
|
|
|
|
|
normalize_img_fn = identity,
|
|
|
|
|
unnormalize_img_fn = identity
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.downsample_first = downsample_first
|
|
|
|
|
self.downsample_mode_nearest = downsample_mode_nearest
|
|
|
|
|
|
|
|
|
|
self.input_image_range = input_image_range
|
|
|
|
|
|
|
|
|
|
self.use_blur = use_blur
|
|
|
|
|
self.blur_prob = blur_prob
|
|
|
|
|
self.blur_sigma = blur_sigma
|
|
|
|
|
self.blur_kernel_size = blur_kernel_size
|
|
|
|
|
|
|
|
|
|
self.use_noise = use_noise
|
|
|
|
|
self.normalize_img = normalize_img_fn
|
|
|
|
|
self.unnormalize_img = unnormalize_img_fn
|
|
|
|
|
self.noise_scheduler = NoiseScheduler(beta_schedule = 'linear', timesteps = 1000, loss_type = 'l2') if use_noise else None
|
|
|
|
|
|
|
|
|
|
def noise_image(self, cond_fmap, noise_levels = None):
|
|
|
|
|
assert exists(self.noise_scheduler)
|
|
|
|
|
|
|
|
|
|
batch = cond_fmap.shape[0]
|
|
|
|
|
cond_fmap = self.normalize_img(cond_fmap)
|
|
|
|
|
|
|
|
|
|
random_noise_levels = default(noise_levels, lambda: self.noise_scheduler.sample_random_times(batch))
|
|
|
|
|
cond_fmap = self.noise_scheduler.q_sample(cond_fmap, t = random_noise_levels, noise = torch.randn_like(cond_fmap))
|
|
|
|
|
|
|
|
|
|
cond_fmap = self.unnormalize_img(cond_fmap)
|
|
|
|
|
return cond_fmap, random_noise_levels
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|
cond_fmap,
|
|
|
|
|
*,
|
|
|
|
|
target_image_size,
|
|
|
|
|
downsample_image_size = None,
|
|
|
|
|
should_blur = True,
|
|
|
|
|
blur_sigma = None,
|
|
|
|
|
blur_kernel_size = None
|
|
|
|
|
):
|
|
|
|
|
if self.training and self.downsample_first and exists(downsample_image_size):
|
|
|
|
|
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)
|
|
|
|
|
if self.downsample_first and exists(downsample_image_size):
|
|
|
|
|
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = True)
|
|
|
|
|
|
|
|
|
|
# blur is only applied 50% of the time
|
|
|
|
|
# section 3.1 in https://arxiv.org/abs/2106.15282
|
|
|
|
|
|
|
|
|
|
if self.use_blur and should_blur and random.random() < self.blur_prob:
|
|
|
|
|
|
|
|
|
|
if self.training:
|
|
|
|
|
# when training, blur the low resolution conditional image
|
|
|
|
|
|
|
|
|
|
blur_sigma = default(blur_sigma, self.blur_sigma)
|
|
|
|
|
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
|
|
|
|
|
|
|
|
|
# allow for drawing a random sigma between lo and hi float values
|
|
|
|
|
|
|
|
|
|
if isinstance(blur_sigma, tuple):
|
|
|
|
|
blur_sigma = tuple(map(float, blur_sigma))
|
|
|
|
|
blur_sigma = random.uniform(*blur_sigma)
|
|
|
|
|
|
|
|
|
|
# allow for drawing a random kernel size between lo and hi int values
|
|
|
|
|
|
|
|
|
|
if isinstance(blur_kernel_size, tuple):
|
|
|
|
|
blur_kernel_size = tuple(map(int, blur_kernel_size))
|
|
|
|
|
kernel_size_lo, kernel_size_hi = blur_kernel_size
|
|
|
|
|
@@ -1932,9 +2064,21 @@ class LowresConditioner(nn.Module):
|
|
|
|
|
|
|
|
|
|
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
|
|
|
|
|
|
|
|
|
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range)
|
|
|
|
|
# resize to target image size
|
|
|
|
|
|
|
|
|
|
return cond_fmap
|
|
|
|
|
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range, nearest = True)
|
|
|
|
|
|
|
|
|
|
# noise conditioning, as done in Imagen
|
|
|
|
|
# as a replacement for the BSR noising, and potentially replace blurring for first stage too
|
|
|
|
|
|
|
|
|
|
random_noise_levels = None
|
|
|
|
|
|
|
|
|
|
if self.use_noise:
|
|
|
|
|
cond_fmap, random_noise_levels = self.noise_image(cond_fmap)
|
|
|
|
|
|
|
|
|
|
# return conditioning feature map, as well as the augmentation noise levels
|
|
|
|
|
|
|
|
|
|
return cond_fmap, random_noise_levels
|
|
|
|
|
|
|
|
|
|
class Decoder(nn.Module):
|
|
|
|
|
def __init__(
|
|
|
|
|
@@ -1955,10 +2099,13 @@ class Decoder(nn.Module):
|
|
|
|
|
predict_x_start_for_latent_diffusion = False,
|
|
|
|
|
image_sizes = None, # for cascading ddpm, image size at each stage
|
|
|
|
|
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
|
|
|
|
use_noise_for_lowres_cond = False, # whether to use Imagen-like noising for low resolution conditioning
|
|
|
|
|
use_blur_for_lowres_cond = True, # whether to use the blur conditioning used in the original cascading ddpm paper, as well as DALL-E2
|
|
|
|
|
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
|
|
|
|
lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution
|
|
|
|
|
blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
|
|
|
|
|
blur_sigma = 0.6, # cascading ddpm - blur sigma
|
|
|
|
|
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
|
|
|
|
lowres_noise_sample_level = 0.2, # in imagen paper, they use a 0.2 noise level at sample time for low resolution conditioning
|
|
|
|
|
clip_denoised = True,
|
|
|
|
|
clip_x_start = True,
|
|
|
|
|
clip_adapter_overrides = dict(),
|
|
|
|
|
@@ -2006,10 +2153,17 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
self.channels = channels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# normalize and unnormalize image functions
|
|
|
|
|
|
|
|
|
|
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
|
|
|
|
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
|
|
|
|
|
|
|
|
|
# verify conditioning method
|
|
|
|
|
|
|
|
|
|
unets = cast_tuple(unet)
|
|
|
|
|
num_unets = len(unets)
|
|
|
|
|
self.num_unets = num_unets
|
|
|
|
|
|
|
|
|
|
self.unconditional = unconditional
|
|
|
|
|
|
|
|
|
|
@@ -2025,12 +2179,28 @@ class Decoder(nn.Module):
|
|
|
|
|
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
|
|
|
|
|
self.vb_loss_weight = vb_loss_weight
|
|
|
|
|
|
|
|
|
|
# default and validate conditioning parameters
|
|
|
|
|
|
|
|
|
|
use_noise_for_lowres_cond = cast_tuple(use_noise_for_lowres_cond, num_unets - 1, validate = False)
|
|
|
|
|
use_blur_for_lowres_cond = cast_tuple(use_blur_for_lowres_cond, num_unets - 1, validate = False)
|
|
|
|
|
|
|
|
|
|
if len(use_noise_for_lowres_cond) < num_unets:
|
|
|
|
|
use_noise_for_lowres_cond = (False, *use_noise_for_lowres_cond)
|
|
|
|
|
|
|
|
|
|
if len(use_blur_for_lowres_cond) < num_unets:
|
|
|
|
|
use_blur_for_lowres_cond = (False, *use_blur_for_lowres_cond)
|
|
|
|
|
|
|
|
|
|
assert not use_noise_for_lowres_cond[0], 'first unet will never need low res noise conditioning'
|
|
|
|
|
assert not use_blur_for_lowres_cond[0], 'first unet will never need low res blur conditioning'
|
|
|
|
|
|
|
|
|
|
assert num_unets == 1 or all((use_noise or use_blur) for use_noise, use_blur in zip(use_noise_for_lowres_cond[1:], use_blur_for_lowres_cond[1:]))
|
|
|
|
|
|
|
|
|
|
# construct unets and vaes
|
|
|
|
|
|
|
|
|
|
self.unets = nn.ModuleList([])
|
|
|
|
|
self.vaes = nn.ModuleList([])
|
|
|
|
|
|
|
|
|
|
for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
|
|
|
|
|
for ind, (one_unet, one_vae, one_unet_learned_var, lowres_noise_cond) in enumerate(zip(unets, vaes, learned_variance, use_noise_for_lowres_cond)):
|
|
|
|
|
assert isinstance(one_unet, Unet)
|
|
|
|
|
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
|
|
|
|
|
|
|
|
|
|
@@ -2042,6 +2212,7 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
one_unet = one_unet.cast_model_parameters(
|
|
|
|
|
lowres_cond = not is_first,
|
|
|
|
|
lowres_noise_cond = lowres_noise_cond,
|
|
|
|
|
cond_on_image_embeds = not unconditional and is_first,
|
|
|
|
|
cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings,
|
|
|
|
|
channels = unet_channels,
|
|
|
|
|
@@ -2084,13 +2255,14 @@ class Decoder(nn.Module):
|
|
|
|
|
image_sizes = default(image_sizes, (image_size,))
|
|
|
|
|
image_sizes = tuple(sorted(set(image_sizes)))
|
|
|
|
|
|
|
|
|
|
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
|
|
|
|
assert self.num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({self.num_unets}) for resolutions {image_sizes}'
|
|
|
|
|
self.image_sizes = image_sizes
|
|
|
|
|
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
|
|
|
|
|
|
|
|
|
|
# random crop sizes (for super-resoluting unets at the end of cascade?)
|
|
|
|
|
|
|
|
|
|
self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))
|
|
|
|
|
assert not exists(self.random_crop_sizes[0]), 'you would not need to randomly crop the image for the base unet'
|
|
|
|
|
|
|
|
|
|
# predict x0 config
|
|
|
|
|
|
|
|
|
|
@@ -2103,15 +2275,30 @@ class Decoder(nn.Module):
|
|
|
|
|
# cascading ddpm related stuff
|
|
|
|
|
|
|
|
|
|
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
|
|
|
|
|
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
|
|
|
|
assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
|
|
|
|
|
|
|
|
|
self.to_lowres_cond = LowresConditioner(
|
|
|
|
|
downsample_first = lowres_downsample_first,
|
|
|
|
|
downsample_mode_nearest = lowres_downsample_mode_nearest,
|
|
|
|
|
blur_sigma = blur_sigma,
|
|
|
|
|
blur_kernel_size = blur_kernel_size,
|
|
|
|
|
input_image_range = self.input_image_range
|
|
|
|
|
)
|
|
|
|
|
self.lowres_conds = nn.ModuleList([])
|
|
|
|
|
|
|
|
|
|
for unet_index, use_noise, use_blur in zip(range(num_unets), use_noise_for_lowres_cond, use_blur_for_lowres_cond):
|
|
|
|
|
if unet_index == 0:
|
|
|
|
|
self.lowres_conds.append(None)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
lowres_cond = LowresConditioner(
|
|
|
|
|
downsample_first = lowres_downsample_first,
|
|
|
|
|
use_blur = use_blur,
|
|
|
|
|
use_noise = use_noise,
|
|
|
|
|
blur_prob = blur_prob,
|
|
|
|
|
blur_sigma = blur_sigma,
|
|
|
|
|
blur_kernel_size = blur_kernel_size,
|
|
|
|
|
input_image_range = self.input_image_range,
|
|
|
|
|
normalize_img_fn = self.normalize_img,
|
|
|
|
|
unnormalize_img_fn = self.unnormalize_img
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.lowres_conds.append(lowres_cond)
|
|
|
|
|
|
|
|
|
|
self.lowres_noise_sample_level = lowres_noise_sample_level
|
|
|
|
|
|
|
|
|
|
# classifier free guidance
|
|
|
|
|
|
|
|
|
|
@@ -2129,11 +2316,6 @@ class Decoder(nn.Module):
|
|
|
|
|
self.use_dynamic_thres = use_dynamic_thres
|
|
|
|
|
self.dynamic_thres_percentile = dynamic_thres_percentile
|
|
|
|
|
|
|
|
|
|
# normalize and unnormalize image functions
|
|
|
|
|
|
|
|
|
|
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
|
|
|
|
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
|
|
|
|
|
|
|
|
|
# device tracker
|
|
|
|
|
|
|
|
|
|
self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)
|
|
|
|
|
@@ -2147,7 +2329,7 @@ class Decoder(nn.Module):
|
|
|
|
|
return any([unet.cond_on_text_encodings for unet in self.unets])
|
|
|
|
|
|
|
|
|
|
def get_unet(self, unet_number):
|
|
|
|
|
assert 0 < unet_number <= len(self.unets)
|
|
|
|
|
assert 0 < unet_number <= self.num_unets
|
|
|
|
|
index = unet_number - 1
|
|
|
|
|
return self.unets[index]
|
|
|
|
|
|
|
|
|
|
@@ -2189,10 +2371,10 @@ class Decoder(nn.Module):
|
|
|
|
|
x = x.clamp(-s, s) / s
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
|
|
|
|
|
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
|
|
|
|
|
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
|
|
|
|
|
|
|
|
|
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
|
|
|
|
|
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level))
|
|
|
|
|
|
|
|
|
|
if learned_variance:
|
|
|
|
|
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
|
|
|
|
|
@@ -2224,16 +2406,16 @@ class Decoder(nn.Module):
|
|
|
|
|
return model_mean, posterior_variance, posterior_log_variance
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
|
|
|
|
|
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
|
|
|
|
|
b, *_, device = *x.shape, x.device
|
|
|
|
|
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
|
|
|
|
|
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
|
|
|
|
|
noise = torch.randn_like(x)
|
|
|
|
|
# no noise when t == 0
|
|
|
|
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
|
|
|
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
|
|
|
|
|
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
|
|
|
|
|
device = self.device
|
|
|
|
|
|
|
|
|
|
b = shape[0]
|
|
|
|
|
@@ -2249,9 +2431,9 @@ class Decoder(nn.Module):
|
|
|
|
|
torch.full((b,), i, device = device, dtype = torch.long),
|
|
|
|
|
image_embed = image_embed,
|
|
|
|
|
text_encodings = text_encodings,
|
|
|
|
|
text_mask = text_mask,
|
|
|
|
|
cond_scale = cond_scale,
|
|
|
|
|
lowres_cond_img = lowres_cond_img,
|
|
|
|
|
lowres_noise_level = lowres_noise_level,
|
|
|
|
|
predict_x_start = predict_x_start,
|
|
|
|
|
noise_scheduler = noise_scheduler,
|
|
|
|
|
learned_variance = learned_variance,
|
|
|
|
|
@@ -2262,7 +2444,7 @@ class Decoder(nn.Module):
|
|
|
|
|
return unnormalize_img
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
|
|
|
|
|
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
|
|
|
|
|
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
|
|
|
|
|
|
|
|
|
|
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
|
|
|
|
|
@@ -2272,13 +2454,16 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
img = torch.randn(shape, device = device)
|
|
|
|
|
|
|
|
|
|
if not is_latent_diffusion:
|
|
|
|
|
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
|
|
|
|
|
|
|
|
|
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
|
|
|
|
|
alpha = alphas[time]
|
|
|
|
|
alpha_next = alphas[time_next]
|
|
|
|
|
|
|
|
|
|
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
|
|
|
|
|
|
|
|
|
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
|
|
|
|
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
|
|
|
|
|
|
|
|
|
if learned_variance:
|
|
|
|
|
pred, _ = pred.chunk(2, dim = 1)
|
|
|
|
|
@@ -2317,7 +2502,7 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
|
|
|
|
|
|
|
|
|
|
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
|
|
|
|
|
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
|
|
|
|
|
noise = default(noise, lambda: torch.randn_like(x_start))
|
|
|
|
|
|
|
|
|
|
# normalize to [-1, 1]
|
|
|
|
|
@@ -2335,8 +2520,8 @@ class Decoder(nn.Module):
|
|
|
|
|
times,
|
|
|
|
|
image_embed = image_embed,
|
|
|
|
|
text_encodings = text_encodings,
|
|
|
|
|
text_mask = text_mask,
|
|
|
|
|
lowres_cond_img = lowres_cond_img,
|
|
|
|
|
lowres_noise_level = lowres_noise_level,
|
|
|
|
|
image_cond_drop_prob = self.image_cond_drop_prob,
|
|
|
|
|
text_cond_drop_prob = self.text_cond_drop_prob,
|
|
|
|
|
)
|
|
|
|
|
@@ -2395,7 +2580,6 @@ class Decoder(nn.Module):
|
|
|
|
|
self,
|
|
|
|
|
image_embed = None,
|
|
|
|
|
text = None,
|
|
|
|
|
text_mask = None,
|
|
|
|
|
text_encodings = None,
|
|
|
|
|
batch_size = 1,
|
|
|
|
|
cond_scale = 1.,
|
|
|
|
|
@@ -2409,27 +2593,31 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
if exists(text) and not exists(text_encodings) and not self.unconditional:
|
|
|
|
|
assert exists(self.clip)
|
|
|
|
|
_, text_encodings, text_mask = self.clip.embed_text(text)
|
|
|
|
|
_, text_encodings = self.clip.embed_text(text)
|
|
|
|
|
|
|
|
|
|
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
|
|
|
|
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
|
|
|
|
|
|
|
|
|
if self.condition_on_text_encodings:
|
|
|
|
|
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
|
|
|
|
|
|
|
|
|
|
img = None
|
|
|
|
|
is_cuda = next(self.parameters()).is_cuda
|
|
|
|
|
|
|
|
|
|
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):
|
|
|
|
|
num_unets = self.num_unets
|
|
|
|
|
cond_scale = cast_tuple(cond_scale, num_unets)
|
|
|
|
|
|
|
|
|
|
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
|
|
|
|
|
|
|
|
|
|
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
|
|
|
|
|
|
|
|
|
|
with context:
|
|
|
|
|
lowres_cond_img = None
|
|
|
|
|
lowres_cond_img = lowres_noise_level = None
|
|
|
|
|
shape = (batch_size, channel, image_size, image_size)
|
|
|
|
|
|
|
|
|
|
if unet.lowres_cond:
|
|
|
|
|
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
|
|
|
|
lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = True)
|
|
|
|
|
|
|
|
|
|
if lowres_cond.use_noise:
|
|
|
|
|
lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)
|
|
|
|
|
lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)
|
|
|
|
|
|
|
|
|
|
is_latent_diffusion = isinstance(vae, VQGanVAE)
|
|
|
|
|
image_size = vae.get_encoded_fmap_size(image_size)
|
|
|
|
|
@@ -2442,12 +2630,12 @@ class Decoder(nn.Module):
|
|
|
|
|
shape,
|
|
|
|
|
image_embed = image_embed,
|
|
|
|
|
text_encodings = text_encodings,
|
|
|
|
|
text_mask = text_mask,
|
|
|
|
|
cond_scale = cond_scale,
|
|
|
|
|
cond_scale = unet_cond_scale,
|
|
|
|
|
predict_x_start = predict_x_start,
|
|
|
|
|
learned_variance = learned_variance,
|
|
|
|
|
clip_denoised = not is_latent_diffusion,
|
|
|
|
|
lowres_cond_img = lowres_cond_img,
|
|
|
|
|
lowres_noise_level = lowres_noise_level,
|
|
|
|
|
is_latent_diffusion = is_latent_diffusion,
|
|
|
|
|
noise_scheduler = noise_scheduler,
|
|
|
|
|
timesteps = sample_timesteps
|
|
|
|
|
@@ -2466,11 +2654,10 @@ class Decoder(nn.Module):
|
|
|
|
|
text = None,
|
|
|
|
|
image_embed = None,
|
|
|
|
|
text_encodings = None,
|
|
|
|
|
text_mask = None,
|
|
|
|
|
unet_number = None,
|
|
|
|
|
return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
|
|
|
|
|
):
|
|
|
|
|
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
|
|
|
|
assert not (self.num_unets > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {self.num_unets}, if you are training cascading DDPM (multiple unets)'
|
|
|
|
|
unet_number = default(unet_number, 1)
|
|
|
|
|
unet_index = unet_number - 1
|
|
|
|
|
|
|
|
|
|
@@ -2478,6 +2665,7 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
vae = self.vaes[unet_index]
|
|
|
|
|
noise_scheduler = self.noise_schedulers[unet_index]
|
|
|
|
|
lowres_conditioner = self.lowres_conds[unet_index]
|
|
|
|
|
target_image_size = self.image_sizes[unet_index]
|
|
|
|
|
predict_x_start = self.predict_x_start[unet_index]
|
|
|
|
|
random_crop_size = self.random_crop_sizes[unet_index]
|
|
|
|
|
@@ -2495,16 +2683,13 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
if exists(text) and not exists(text_encodings) and not self.unconditional:
|
|
|
|
|
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
|
|
|
|
|
_, text_encodings, text_mask = self.clip.embed_text(text)
|
|
|
|
|
_, text_encodings = self.clip.embed_text(text)
|
|
|
|
|
|
|
|
|
|
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
|
|
|
|
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
|
|
|
|
|
|
|
|
|
if self.condition_on_text_encodings:
|
|
|
|
|
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
|
|
|
|
|
|
|
|
|
|
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
|
|
|
|
image = resize_image_to(image, target_image_size)
|
|
|
|
|
lowres_cond_img, lowres_noise_level = lowres_conditioner(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if exists(lowres_conditioner) else (None, None)
|
|
|
|
|
image = resize_image_to(image, target_image_size, nearest = True)
|
|
|
|
|
|
|
|
|
|
if exists(random_crop_size):
|
|
|
|
|
aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)
|
|
|
|
|
@@ -2521,7 +2706,7 @@ class Decoder(nn.Module):
|
|
|
|
|
image = vae.encode(image)
|
|
|
|
|
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
|
|
|
|
|
|
|
|
|
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
|
|
|
|
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
|
|
|
|
|
|
|
|
|
|
if not return_lowres_cond_image:
|
|
|
|
|
return losses
|
|
|
|
|
|