@@ -77,6 +77,11 @@ def cast_tuple(val, length = None):
 def module_device(module):
     return next(module.parameters()).device

+def zero_init_(m):
+    nn.init.zeros_(m.weight)
+    if exists(m.bias):
+        nn.init.zeros_(m.bias)
+
 @contextmanager
 def null_context(*args, **kwargs):
     yield
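
Note (annotation, not part of the diff): zero_init_ is applied further down to the unet's final convolution (zero_init_(self.to_out)), so a freshly initialized unet outputs zeros. A minimal standalone sketch of that effect, assuming only torch:

    import torch
    from torch import nn

    def exists(val):
        return val is not None

    def zero_init_(m):
        nn.init.zeros_(m.weight)
        if exists(m.bias):
            nn.init.zeros_(m.bias)

    conv = nn.Conv2d(8, 3, 1)
    zero_init_(conv)
    # zeroed weights and bias -> the layer outputs exactly zero at initialization
    assert conv(torch.randn(1, 8, 16, 16)).abs().sum().item() == 0.0
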
@@ -160,7 +165,7 @@ def unnormalize_zero_to_one(normed_img):

 # clip related adapters

-EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
+EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings'])
 EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])

 class BaseClipAdapter(nn.Module):
@@ -220,7 +225,8 @@ class XClipAdapter(BaseClipAdapter):
         encoder_output = self.clip.text_transformer(text)
         text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
-        return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
+        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
+        return EmbeddedText(l2norm(text_embed), text_encodings)

     @torch.no_grad()
     def embed_image(self, image):
@@ -255,7 +261,8 @@ class CoCaAdapter(BaseClipAdapter):
         text = text[..., :self.max_text_len]
         text_mask = text != 0
         text_embed, text_encodings = self.clip.embed_text(text)
-        return EmbeddedText(text_embed, text_encodings, text_mask)
+        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
+        return EmbeddedText(text_embed, text_encodings)

     @torch.no_grad()
     def embed_image(self, image):
@@ -314,8 +321,9 @@ class OpenAIClipAdapter(BaseClipAdapter):
         text_embed = self.clip.encode_text(text)
         text_encodings = self.text_encodings
+        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
         del self.text_encodings
-        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
+        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())

     @torch.no_grad()
     def embed_image(self, image):
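
Note (annotation, not part of the diff): all three adapters now zero out padded positions of text_encodings instead of returning a separate text_mask, so consumers can recover the mask with torch.any(text_encodings != 0., dim = -1), as the prior and unet hunks below do. A small self-contained sketch of that round trip, with toy shapes assumed:

    import torch

    text_encodings = torch.randn(2, 4, 8)
    text_mask = torch.tensor([[True, True, False, False],
                              [True, True, True,  False]])

    # what the adapters now return: encodings with padding positions zeroed out
    text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)

    # what downstream code recovers without being handed the mask explicitly
    recovered = torch.any(text_encodings != 0., dim = -1)
    assert torch.equal(recovered, text_mask)
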
@@ -752,6 +760,7 @@ class CausalTransformer(nn.Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
+        norm_in = False,
         norm_out = True,
         attn_dropout = 0.,
         ff_dropout = 0.,
@@ -760,6 +769,8 @@ class CausalTransformer(nn.Module):
         rotary_emb = True
     ):
         super().__init__()
+        self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
+
         self.rel_pos_bias = RelPosBias(heads = heads)

         rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
@@ -774,17 +785,15 @@ class CausalTransformer(nn.Module):
         self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
         self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()

-    def forward(
-        self,
-        x,
-        mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
-    ):
+    def forward(self, x):
         n, device = x.shape[1], x.device

+        x = self.init_norm(x)
+
         attn_bias = self.rel_pos_bias(n, n + 1, device = device)

         for attn, ff in self.layers:
-            x = attn(x, mask = mask, attn_bias = attn_bias) + x
+            x = attn(x, attn_bias = attn_bias) + x
             x = ff(x) + x

         out = self.norm(x)
@@ -798,6 +807,7 @@ class DiffusionPriorNetwork(nn.Module):
         num_time_embeds = 1,
         num_image_embeds = 1,
         num_text_embeds = 1,
+        max_text_len = 256,
         **kwargs
     ):
         super().__init__()
@@ -823,6 +833,11 @@ class DiffusionPriorNetwork(nn.Module):
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)

+        # dalle1 learned padding strategy
+
+        self.max_text_len = max_text_len
+        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
+
     def forward_with_cond_scale(
         self,
         *args,
@@ -844,7 +859,6 @@ class DiffusionPriorNetwork(nn.Module):
         *,
         text_embed,
         text_encodings = None,
-        mask = None,
         cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
@@ -862,9 +876,28 @@ class DiffusionPriorNetwork(nn.Module):
         if not exists(text_encodings):
             text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)

-        if not exists(mask):
-            mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
+        mask = torch.any(text_encodings != 0., dim = -1)
+
+        # replace any padding in the text encodings with learned padding tokens unique across position
+
+        text_encodings = text_encodings[:, :self.max_text_len]
+        mask = mask[:, :self.max_text_len]
+
+        text_len = text_encodings.shape[-2]
+        remainder = self.max_text_len - text_len
+
+        if remainder > 0:
+            text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
+            mask = F.pad(mask, (0, remainder), value = False)
+
+        null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
+
+        text_encodings = torch.where(
+            rearrange(mask, 'b n -> b n 1'),
+            text_encodings,
+            null_text_embeds
+        )

         # classifier free guidance
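
Note (annotation, not part of the diff): the prior network now falls back to the DALL-E 1 strategy of learned, position-specific padding embeddings, which is why the causal transformer below can attend over everything without a mask. A standalone sketch of the substitution above, with toy sizes in place of the real dim and max_text_len:

    import torch
    import torch.nn.functional as F
    from torch import nn
    from einops import rearrange

    batch, max_text_len, dim = 2, 6, 8
    null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))

    text_encodings = torch.randn(batch, 4, dim)      # shorter than max_text_len
    mask = torch.any(text_encodings != 0., dim = -1)

    remainder = max_text_len - text_encodings.shape[-2]
    text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
    mask = F.pad(mask, (0, remainder), value = False)

    # padding (False) positions pick up the learned null embedding for their position
    text_encodings = torch.where(
        rearrange(mask, 'b n -> b n 1'),
        text_encodings,
        null_text_embed.to(text_encodings.dtype)
    )
    assert text_encodings.shape == (batch, max_text_len, dim)
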
@@ -881,9 +914,8 @@ class DiffusionPriorNetwork(nn.Module):
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right

-        if exists(mask):
-            attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
-            mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
+        attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
+        mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query

         time_embed = self.to_time_embeds(diffusion_timesteps)
@@ -899,7 +931,7 @@ class DiffusionPriorNetwork(nn.Module):

         # attend

-        tokens = self.causal_transformer(tokens, mask = mask)
+        tokens = self.causal_transformer(tokens)

         # get learned query, which should predict the image embedding (per DDPM timestep)
@@ -922,11 +954,12 @@ class DiffusionPrior(nn.Module):
         loss_type = "l2",
         predict_x_start = True,
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
-        sampling_clamp_l2norm = False,
+        sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
+        sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
         training_clamp_l2norm = False,
         init_image_embed_l2norm = False,
         image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
         clip_adapter_overrides = dict()
     ):
         super().__init__()
@@ -963,23 +996,32 @@ class DiffusionPrior(nn.Module):
         self.condition_on_text_encodings = condition_on_text_encodings

         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

         self.predict_x_start = predict_x_start

         # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132

         self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)

         # whether to force an l2norm, similar to clipping denoised, when sampling

         self.sampling_clamp_l2norm = sampling_clamp_l2norm
+        self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
+
         self.training_clamp_l2norm = training_clamp_l2norm
         self.init_image_embed_l2norm = init_image_embed_l2norm

         # device tracker

         self.register_buffer('_dummy', torch.tensor([True]), persistent = False)

     @property
     def device(self):
         return self._dummy.device

+    def l2norm_clamp_embed(self, image_embed):
+        return l2norm(image_embed) * self.image_embed_scale
+
     def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
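
Note (annotation, not part of the diff): per the @crowsonkb issue linked above, the image embedding is kept on an l2-normed sphere and rescaled by sqrt(dim), which puts individual coordinates back on a roughly unit scale for Gaussian diffusion. A quick sketch of the clamp, with dim assumed to be 512 for illustration and a local stand-in for the library's l2norm helper:

    import torch
    import torch.nn.functional as F

    def l2norm(t):
        return F.normalize(t, dim = -1)

    image_embed_dim = 512                       # assumed for illustration
    image_embed_scale = image_embed_dim ** 0.5

    image_embed = torch.randn(4, image_embed_dim) * 3.   # arbitrary incoming scale
    clamped = l2norm(image_embed) * image_embed_scale

    print(clamped.norm(dim = -1))   # every row has norm sqrt(512) ~ 22.6
    print(clamped.std(dim = -1))    # per-coordinate scale ~ 1, as gaussian diffusion expects
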
@@ -1020,6 +1062,9 @@ class DiffusionPrior(nn.Module):
             times = torch.full((batch,), i, device = device, dtype = torch.long)
             image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)

+        if self.sampling_final_clamp_l2norm and self.predict_x_start:
+            image_embed = self.l2norm_clamp_embed(image_embed)
+
         return image_embed

     @torch.no_grad()
@@ -1055,15 +1100,18 @@ class DiffusionPrior(nn.Module):
                 x_start.clamp_(-1., 1.)

             if self.predict_x_start and self.sampling_clamp_l2norm:
-                x_start = l2norm(x_start) * self.image_embed_scale
+                x_start = self.l2norm_clamp_embed(x_start)

             c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
             c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
-            new_noise = torch.randn_like(image_embed)
+            noise = torch.randn_like(image_embed) if time_next > 0 else 0.

-            img = x_start * alpha_next.sqrt() + \
-                  c1 * new_noise + \
-                  c2 * pred_noise
+            image_embed = x_start * alpha_next.sqrt() + \
+                c1 * noise + \
+                c2 * pred_noise
+
+        if self.predict_x_start and self.sampling_final_clamp_l2norm:
+            image_embed = self.l2norm_clamp_embed(image_embed)

         return image_embed
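
Note (annotation, not part of the diff): besides switching to the l2norm_clamp_embed helper, the DDIM step above now writes its result back into image_embed (the old code apparently assigned to img, so the loop variable was never updated) and only adds fresh noise when time_next > 0. A compact standalone sketch of the update rule as used here:

    import torch

    def ddim_step(x_start, pred_noise, alpha, alpha_next, time_next, eta = 1.):
        c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
        c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
        noise = torch.randn_like(x_start) if time_next > 0 else 0.   # no noise on the final step
        return x_start * alpha_next.sqrt() + c1 * noise + c2 * pred_noise

    out = ddim_step(torch.randn(2, 8), torch.randn(2, 8), torch.tensor(0.5), torch.tensor(0.7), time_next = 0)
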
@@ -1091,7 +1139,7 @@ class DiffusionPrior(nn.Module):
         )

         if self.predict_x_start and self.training_clamp_l2norm:
-            pred = l2norm(pred) * self.image_embed_scale
+            pred = self.l2norm_clamp_embed(pred)

         target = noise if not self.predict_x_start else image_embed
@@ -1128,12 +1176,12 @@ class DiffusionPrior(nn.Module):
         batch_size = text.shape[0]
         image_embed_dim = self.image_embed_dim

-        text_embed, text_encodings, text_mask = self.clip.embed_text(text)
+        text_embed, text_encodings = self.clip.embed_text(text)

         text_cond = dict(text_embed = text_embed)

         if self.condition_on_text_encodings:
-            text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
+            text_cond = {**text_cond, 'text_encodings': text_encodings}

         image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
@@ -1161,7 +1209,6 @@ class DiffusionPrior(nn.Module):
         text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
         image_embed = None,
         text_encodings = None, # as well as CLIP text encodings
-        text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
         *args,
         **kwargs
     ):
@@ -1175,13 +1222,13 @@ class DiffusionPrior(nn.Module):
         # calculate text conditionings, based on what is passed in

         if exists(text):
-            text_embed, text_encodings, text_mask = self.clip.embed_text(text)
+            text_embed, text_encodings = self.clip.embed_text(text)

         text_cond = dict(text_embed = text_embed)

         if self.condition_on_text_encodings:
             assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
-            text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
+            text_cond = {**text_cond, 'text_encodings': text_encodings}

         # timestep conditioning from ddpm
@@ -1198,16 +1245,35 @@ class DiffusionPrior(nn.Module):

 # decoder

 def ConvTransposeUpsample(dim, dim_out = None):
     dim_out = default(dim_out, dim)
     return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)

+class PixelShuffleUpsample(nn.Module):
+    """
+    code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
+    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
+    """
+    def __init__(self, dim, dim_out = None):
+        super().__init__()
+        dim_out = default(dim_out, dim)
+        conv = nn.Conv2d(dim, dim_out * 4, 1)
+
+        self.net = nn.Sequential(
+            conv,
+            nn.SiLU(),
+            nn.PixelShuffle(2)
+        )
+
+        self.init_conv_(conv)
+
+    def init_conv_(self, conv):
+        o, i, h, w = conv.weight.shape
+        conv_weight = torch.empty(o // 4, i, h, w)
+        nn.init.kaiming_uniform_(conv_weight)
+        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
+
+        conv.weight.data.copy_(conv_weight)
+        nn.init.zeros_(conv.bias.data)
+
+    def forward(self, x):
+        return self.net(x)
+
 def NearestUpsample(dim, dim_out = None):
     dim_out = default(dim_out, dim)
     return nn.Sequential(
         nn.Upsample(scale_factor = 2, mode = 'nearest'),
         nn.Conv2d(dim, dim_out, 3, padding = 1)
     )

 def Downsample(dim, *, dim_out = None):
     dim_out = default(dim_out, dim)
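
Note (annotation, not part of the diff): init_conv_ above is the ICNR-style trick from the linked paper. One kernel is drawn for dim_out channels, repeated four times along the output dimension, and biases are zeroed, so immediately after initialization conv followed by PixelShuffle(2) places the same value at all four sub-pixel positions, i.e. it behaves like a nearest-neighbor upsample and avoids checkerboard artifacts. A standalone check of that property, with toy sizes and the SiLU left out so the comparison is exact:

    import torch
    from torch import nn
    import torch.nn.functional as F
    from einops import repeat

    dim = dim_out = 8
    conv = nn.Conv2d(dim, dim_out * 4, 1)

    o, i, h, w = conv.weight.shape
    kernel = torch.empty(o // 4, i, h, w)
    nn.init.kaiming_uniform_(kernel)
    conv.weight.data.copy_(repeat(kernel, 'o ... -> (o 4) ...'))
    nn.init.zeros_(conv.bias.data)

    upsample = nn.Sequential(conv, nn.PixelShuffle(2))

    x = torch.randn(1, dim, 4, 4)
    # channels 0, 4, 8, ... each hold one distinct kernel's response
    nearest = F.interpolate(conv(x)[:, ::4], scale_factor = 2, mode = 'nearest')
    assert torch.allclose(upsample(x), nearest, atol = 1e-6)
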
@@ -1471,7 +1537,7 @@ class Unet(nn.Module):
         cross_embed_downsample_kernel_sizes = (2, 4),
         memory_efficient = False,
         scale_skip_connection = False,
-        nearest_upsample = False,
+        pixel_shuffle_upsample = True,
         final_conv_kernel_size = 1,
         **kwargs
     ):
@@ -1585,7 +1651,7 @@ class Unet(nn.Module):

         # upsample klass

-        upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
+        upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample

         # give memory efficient unet an initial resnet block
@@ -1649,6 +1715,8 @@ class Unet(nn.Module):
         self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
         self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)

+        zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
+
     # if the current settings for the unet are not correct
     # for cascading DDPM, then reinit the unet with the right settings

     def cast_model_parameters(
@@ -1699,7 +1767,6 @@ class Unet(nn.Module):
         image_embed,
         lowres_cond_img = None,
         text_encodings = None,
-        text_mask = None,
         image_cond_drop_prob = 0.,
         text_cond_drop_prob = 0.,
         blur_sigma = None,
@@ -1771,23 +1838,27 @@ class Unet(nn.Module):
         text_tokens = None

         if exists(text_encodings) and self.cond_on_text_encodings:
             assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'
             assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'

+            text_mask = torch.any(text_encodings != 0., dim = -1)
+
             text_tokens = self.text_to_cond(text_encodings)

             text_tokens = text_tokens[:, :self.max_text_len]
+            text_mask = text_mask[:, :self.max_text_len]

             text_tokens_len = text_tokens.shape[1]
             remainder = self.max_text_len - text_tokens_len

             if remainder > 0:
                 text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
+                text_mask = F.pad(text_mask, (0, remainder), value = False)

-            if exists(text_mask):
-                if remainder > 0:
-                    text_mask = F.pad(text_mask, (0, remainder), value = False)
-
-                text_mask = rearrange(text_mask, 'b n -> b n 1')
-                text_keep_mask = text_mask & text_keep_mask
+            text_mask = rearrange(text_mask, 'b n -> b n 1')
+
+            assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}. text encoding is of shape {text_encodings.shape}'
+            text_keep_mask = text_mask & text_keep_mask

             null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
@@ -2047,7 +2118,7 @@ class Decoder(nn.Module):
         self.noise_schedulers = nn.ModuleList([])

         for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
-            assert sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
+            assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'

             noise_scheduler = NoiseScheduler(
                 beta_schedule = unet_beta_schedule,
@@ -2169,10 +2240,10 @@ class Decoder(nn.Module):
         x = x.clamp(-s, s) / s
         return x

-    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
+    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

-        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
+        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))

         if learned_variance:
             pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
@@ -2204,16 +2275,16 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance

     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
+    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+    def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
         device = self.device

         b = shape[0]
@@ -2229,7 +2300,6 @@ class Decoder(nn.Module):
                 torch.full((b,), i, device = device, dtype = torch.long),
                 image_embed = image_embed,
                 text_encodings = text_encodings,
-                text_mask = text_mask,
                 cond_scale = cond_scale,
                 lowres_cond_img = lowres_cond_img,
                 predict_x_start = predict_x_start,
@@ -2242,7 +2312,7 @@ class Decoder(nn.Module):
         return unnormalize_img

     @torch.no_grad()
-    def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+    def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
         batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta

         times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
@@ -2258,7 +2328,7 @@ class Decoder(nn.Module):

             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)

-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)

             if learned_variance:
                 pred, _ = pred.chunk(2, dim = 1)
@@ -2275,9 +2345,10 @@ class Decoder(nn.Module):

             c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
             c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+            noise = torch.randn_like(img) if time_next > 0 else 0.

             img = x_start * alpha_next.sqrt() + \
-                  c1 * torch.randn_like(img) + \
+                  c1 * noise + \
                   c2 * pred_noise

         img = self.unnormalize_img(img)
@@ -2296,7 +2367,7 @@ class Decoder(nn.Module):

         return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)

-    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
+    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
         noise = default(noise, lambda: torch.randn_like(x_start))

         # normalize to [-1, 1]
@@ -2314,7 +2385,6 @@ class Decoder(nn.Module):
             times,
             image_embed = image_embed,
             text_encodings = text_encodings,
-            text_mask = text_mask,
             lowres_cond_img = lowres_cond_img,
             image_cond_drop_prob = self.image_cond_drop_prob,
             text_cond_drop_prob = self.text_cond_drop_prob,
@@ -2374,7 +2444,6 @@ class Decoder(nn.Module):
         self,
         image_embed = None,
         text = None,
-        text_mask = None,
         text_encodings = None,
         batch_size = 1,
         cond_scale = 1.,
@@ -2388,7 +2457,7 @@ class Decoder(nn.Module):

         if exists(text) and not exists(text_encodings) and not self.unconditional:
             assert exists(self.clip)
-            _, text_encodings, text_mask = self.clip.embed_text(text)
+            _, text_encodings = self.clip.embed_text(text)

         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
@@ -2418,7 +2487,6 @@ class Decoder(nn.Module):
                 shape,
                 image_embed = image_embed,
                 text_encodings = text_encodings,
-                text_mask = text_mask,
                 cond_scale = cond_scale,
                 predict_x_start = predict_x_start,
                 learned_variance = learned_variance,
@@ -2442,7 +2510,6 @@ class Decoder(nn.Module):
         text = None,
         image_embed = None,
         text_encodings = None,
-        text_mask = None,
         unet_number = None,
         return_lowres_cond_image = False  # whether to return the low resolution conditioning images, for debugging upsampler purposes
     ):
@@ -2471,7 +2538,7 @@ class Decoder(nn.Module):

         if exists(text) and not exists(text_encodings) and not self.unconditional:
             assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
-            _, text_encodings, text_mask = self.clip.embed_text(text)
+            _, text_encodings = self.clip.embed_text(text)

         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
@@ -2494,7 +2561,7 @@ class Decoder(nn.Module):
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)

-        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
+        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)

         if not return_lowres_cond_image:
             return losses