@@ -278,6 +278,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         import clip
         openai_clip, preprocess = clip.load(name)
         super().__init__(openai_clip)
+        self.eos_id = 49407 # for handling 0 being also '!'
 
         text_attention_final = self.find_layer('ln_final')
         self.handle = text_attention_final.register_forward_hook(self._hook)
@@ -316,7 +317,10 @@ class OpenAIClipAdapter(BaseClipAdapter):
     @torch.no_grad()
     def embed_text(self, text):
         text = text[..., :self.max_text_len]
-        text_mask = text != 0
+
+        is_eos_id = (text == self.eos_id)
+        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
+        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
         assert not self.cleared
 
         text_embed = self.clip.encode_text(text)
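The hunk above replaces the `text != 0` padding mask with one derived from the EOS token, because token id 0 doubles as '!' in CLIP's BPE vocabulary. A minimal standalone sketch (toy token ids of my own, not taken from the patch) of what that buys:

```python
import torch
import torch.nn.functional as F

eos_id = 49407  # CLIP BPE end-of-text token, as hardcoded above

# toy sequence: id 0 is also '!' in the CLIP vocab, so it can occur inside a real
# caption as well as in the zero padding that follows the EOS token
text = torch.tensor([[320, 0, 512, eos_id, 0, 0]])

is_eos_id = (text == eos_id)
mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0       # True strictly before EOS
mask = F.pad(mask_excluding_eos, (1, -1), value = True)    # shift right so EOS itself is kept

print(mask)       # tensor([[ True,  True,  True,  True, False, False]])
print(text != 0)  # tensor([[ True, False,  True,  True, False, False]]) <- wrongly drops the '!' at index 1
```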
@@ -527,25 +531,31 @@ class NoiseScheduler(nn.Module):
 # diffusion prior
 
 class LayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps = 1e-5, stable = False):
         super().__init__()
         self.eps = eps
+        self.stable = stable
         self.g = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
-        x = x / x.amax(dim = -1, keepdim = True).detach()
+        if self.stable:
+            x = x / x.amax(dim = -1, keepdim = True).detach()
+
         var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = -1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g
 
 class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps = 1e-5, stable = False):
         super().__init__()
         self.eps = eps
+        self.stable = stable
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
 
     def forward(self, x):
-        x = x / x.amax(dim = 1, keepdim = True).detach()
+        if self.stable:
+            x = x / x.amax(dim = 1, keepdim = True).detach()
+
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g
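The `stable = True` path rescales each row by its max before normalizing. Layernorm divides out any common scale, so this leaves the output essentially unchanged while keeping the variance computation in a safe range for half precision. A quick standalone check (function name and test values are mine, purely illustrative):

```python
import torch

def layernorm(x, g, eps = 1e-5, stable = False):
    if stable:
        # dividing by the (detached) row max does not change the normalized result,
        # it only shrinks the intermediates so var/mean cannot overflow in fp16
        x = x / x.amax(dim = -1, keepdim = True).detach()
    var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
    mean = torch.mean(x, dim = -1, keepdim = True)
    return (x - mean) * (var + eps).rsqrt() * g

x = torch.randn(2, 512) * 1e4   # deliberately large activations
g = torch.ones(512)
diff = (layernorm(x, g, stable = False) - layernorm(x, g, stable = True)).abs().max()
print(diff)  # small (~1e-3 or less): only rounding and the eps term differ
```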
@@ -669,7 +679,7 @@ class Attention(nn.Module):
         dropout = 0.,
         causal = False,
         rotary_emb = None,
-        pb_relax_alpha = 32 ** 2
+        pb_relax_alpha = 128
     ):
         super().__init__()
         self.pb_relax_alpha = pb_relax_alpha
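`pb_relax_alpha` parameterizes the PB-Relax trick from CogView: compute the attention logits on queries scaled down by alpha, subtract the per-row max, then scale back up, so the softmax is mathematically unchanged but no intermediate can overflow in fp16. The hunk only changes the default from `32 ** 2` to `128`; the sketch below is my own illustration of the trick, not this repo's exact `Attention.forward`:

```python
import torch

def pb_relax_scores(q, k, alpha = 128):
    # (q / alpha) @ k^T, then (sim - rowmax) * alpha == q @ k^T - rowmax(q @ k^T),
    # so the softmax matches vanilla attention while logits stay small throughout
    sim = torch.einsum('b i d, b j d -> b i j', q / alpha, k)
    sim = (sim - sim.amax(dim = -1, keepdim = True).detach()) * alpha
    return sim.softmax(dim = -1)

q = torch.randn(1, 4, 64) * 5   # the usual d ** -0.5 scale is omitted to keep the example minimal
k = torch.randn(1, 6, 64) * 5
vanilla = torch.einsum('b i d, b j d -> b i j', q, k).softmax(dim = -1)
print((pb_relax_scores(q, k) - vanilla).abs().max())  # ~0
```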
@@ -760,6 +770,7 @@ class CausalTransformer(nn.Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
+        norm_in = False,
         norm_out = True,
         attn_dropout = 0.,
         ff_dropout = 0.,
@@ -768,6 +779,8 @@ class CausalTransformer(nn.Module):
         rotary_emb = True
     ):
         super().__init__()
+        self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
+
         self.rel_pos_bias = RelPosBias(heads = heads)
 
         rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
@@ -779,20 +792,18 @@ class CausalTransformer(nn.Module):
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
 
-        self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
+        self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
         self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
 
-    def forward(
-        self,
-        x,
-        mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
-    ):
+    def forward(self, x):
         n, device = x.shape[1], x.device
+
+        x = self.init_norm(x)
 
         attn_bias = self.rel_pos_bias(n, n + 1, device = device)
 
         for attn, ff in self.layers:
-            x = attn(x, mask = mask, attn_bias = attn_bias) + x
+            x = attn(x, attn_bias = attn_bias) + x
             x = ff(x) + x
 
         out = self.norm(x)
@@ -806,6 +817,7 @@ class DiffusionPriorNetwork(nn.Module):
         num_time_embeds = 1,
         num_image_embeds = 1,
         num_text_embeds = 1,
+        max_text_len = 256,
         **kwargs
     ):
         super().__init__()
@@ -831,6 +843,11 @@ class DiffusionPriorNetwork(nn.Module):
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
 
+        # dalle1 learned padding strategy
+
+        self.max_text_len = max_text_len
+        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
+
     def forward_with_cond_scale(
         self,
         *args,
@@ -852,7 +869,6 @@ class DiffusionPriorNetwork(nn.Module):
         *,
         text_embed,
         text_encodings = None,
-        mask = None,
         cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
@@ -870,9 +886,29 @@ class DiffusionPriorNetwork(nn.Module):
 
         if not exists(text_encodings):
             text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
 
+        mask = torch.any(text_encodings != 0., dim = -1)
+
+        # replace any padding in the text encodings with learned padding tokens unique across position
+
+        text_encodings = text_encodings[:, :self.max_text_len]
+        mask = mask[:, :self.max_text_len]
+
+        text_len = text_encodings.shape[-2]
+        remainder = self.max_text_len - text_len
+
+        if remainder > 0:
+            text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
+            mask = F.pad(mask, (0, remainder), value = False)
+
+        null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
+
+        text_encodings = torch.where(
+            rearrange(mask, 'b n -> b n 1').clone(),
+            text_encodings,
+            null_text_embeds
+        )
+
         # classifier free guidance
 
         keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
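With the `mask` argument gone, padding is handled DALL-E 1 style: slots that were zero padding receive a learned, position-specific embedding instead of being masked out in attention, which is why `CausalTransformer.forward` above no longer takes a mask. A small standalone sketch of that substitution (shapes and names are illustrative, not the repo's):

```python
import torch
from einops import rearrange

batch, max_text_len, dim = 2, 6, 4
null_text_embed = torch.randn(1, max_text_len, dim)   # learned padding tokens, one per position

text_encodings = torch.randn(batch, max_text_len, dim)
text_encodings[0, 3:] = 0.                            # pretend the first sequence is padded after 3 tokens

mask = torch.any(text_encodings != 0., dim = -1)      # True wherever a real encoding is present
text_encodings = torch.where(
    rearrange(mask, 'b n -> b n 1'),                  # broadcast the mask over the feature dim
    text_encodings,
    null_text_embed                                   # padded slots pick up the learned embedding for their position
)
print(mask[0])  # tensor([ True,  True,  True, False, False, False])
```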
@@ -905,7 +941,7 @@ class DiffusionPriorNetwork(nn.Module):
 
         # attend
 
-        tokens = self.causal_transformer(tokens, mask = mask)
+        tokens = self.causal_transformer(tokens)
 
         # get learned query, which should predict the image embedding (per DDPM timestep)
 
@@ -1219,6 +1255,14 @@ class DiffusionPrior(nn.Module):
 
 # decoder
 
+def NearestUpsample(dim, dim_out = None):
+    dim_out = default(dim_out, dim)
+
+    return nn.Sequential(
+        nn.Upsample(scale_factor = 2, mode = 'nearest'),
+        nn.Conv2d(dim, dim_out, 3, padding = 1)
+    )
+
 class PixelShuffleUpsample(nn.Module):
     """
     code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
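The new `NearestUpsample` pairs nearest-neighbour upsampling with a 3x3 convolution, the standard alternative to transposed convolutions for avoiding checkerboard artifacts. A quick shape check using an equivalent standalone module (redefined here so the snippet runs on its own, without the repo's `default` helper):

```python
import torch
from torch import nn

def nearest_upsample(dim, dim_out = None):
    dim_out = dim_out if dim_out is not None else dim
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),   # duplicate pixels 2x in each spatial dim
        nn.Conv2d(dim, dim_out, 3, padding = 1)            # then let a conv smooth and remix channels
    )

up = nearest_upsample(64, 32)
x = torch.randn(1, 64, 16, 16)
print(up(x).shape)  # torch.Size([1, 32, 32, 32])
```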
@@ -1625,7 +1669,7 @@ class Unet(nn.Module):
 
         # upsample klass
 
-        upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
+        upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
 
         # give memory efficient unet an initial resnet block
 
@@ -1812,6 +1856,7 @@ class Unet(nn.Module):
         text_tokens = None
 
         if exists(text_encodings) and self.cond_on_text_encodings:
             assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'
             assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
 
+            text_mask = torch.any(text_encodings != 0., dim = -1)
@@ -1913,6 +1958,7 @@ class LowresConditioner(nn.Module):
         self,
         downsample_first = True,
         downsample_mode_nearest = False,
+        blur_prob = 0.5,
         blur_sigma = 0.6,
         blur_kernel_size = 3,
         input_image_range = None
@@ -1923,6 +1969,7 @@ class LowresConditioner(nn.Module):
 
         self.input_image_range = input_image_range
 
+        self.blur_prob = blur_prob
         self.blur_sigma = blur_sigma
         self.blur_kernel_size = blur_kernel_size
 
@@ -1935,20 +1982,27 @@ class LowresConditioner(nn.Module):
         blur_sigma = None,
         blur_kernel_size = None
     ):
-        if self.training and self.downsample_first and exists(downsample_image_size):
+        if self.downsample_first and exists(downsample_image_size):
             cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)
 
+        if self.training:
+            # blur is only applied 50% of the time
+            # section 3.1 in https://arxiv.org/abs/2106.15282
+
+            if random.random() < self.blur_prob:
                 # when training, blur the low resolution conditional image
 
                 blur_sigma = default(blur_sigma, self.blur_sigma)
                 blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
 
                 # allow for drawing a random sigma between lo and hi float values
 
                 if isinstance(blur_sigma, tuple):
                     blur_sigma = tuple(map(float, blur_sigma))
                     blur_sigma = random.uniform(*blur_sigma)
 
                 # allow for drawing a random kernel size between lo and hi int values
 
                 if isinstance(blur_kernel_size, tuple):
                     blur_kernel_size = tuple(map(int, blur_kernel_size))
                     kernel_size_lo, kernel_size_hi = blur_kernel_size
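The reworked `LowresConditioner.forward` always downsamples first, and only while training applies a Gaussian blur with probability `blur_prob`, optionally drawing sigma from a (lo, hi) range per section 3.1 of the cascaded DDPM paper. A rough standalone sketch of that augmentation path, using kornia's `gaussian_blur2d` as the next hunk does (the random kernel-size draw is cut off in the hunk, so only sigma is drawn here; names and defaults are illustrative):

```python
import random
import torch
from kornia.filters import gaussian_blur2d

def maybe_blur(cond_fmap, blur_prob = 0.5, blur_sigma = (0.4, 0.6), blur_kernel_size = 3, training = True):
    # blur the low-res conditioning image only at training time, and only blur_prob of the time
    if not training or random.random() >= blur_prob:
        return cond_fmap

    # allow sigma to be given as a (lo, hi) range and sampled per call
    if isinstance(blur_sigma, tuple):
        blur_sigma = random.uniform(*map(float, blur_sigma))

    return gaussian_blur2d(cond_fmap, (blur_kernel_size, blur_kernel_size), (blur_sigma, blur_sigma))

lowres = torch.randn(1, 3, 64, 64)
print(maybe_blur(lowres).shape)  # torch.Size([1, 3, 64, 64])
```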
@@ -1957,7 +2011,6 @@ class LowresConditioner(nn.Module):
                 cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
 
         cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range)
 
         return cond_fmap
 
 class Decoder(nn.Module):
@@ -1981,6 +2034,7 @@ class Decoder(nn.Module):
         random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
         lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
         lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution
+        blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
         blur_sigma = 0.6, # cascading ddpm - blur sigma
         blur_kernel_size = 3, # cascading ddpm - blur kernel size
         clip_denoised = True,
@@ -2129,9 +2183,12 @@ class Decoder(nn.Module):
         lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
         assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
 
+        self.lowres_downsample_mode_nearest = lowres_downsample_mode_nearest
 
         self.to_lowres_cond = LowresConditioner(
             downsample_first = lowres_downsample_first,
+            downsample_mode_nearest = lowres_downsample_mode_nearest,
+            blur_prob = blur_prob,
             blur_sigma = blur_sigma,
             blur_kernel_size = blur_kernel_size,
             input_image_range = self.input_image_range
@@ -2295,6 +2352,9 @@ class Decoder(nn.Module):
 
         img = torch.randn(shape, device = device)
 
+        if not is_latent_diffusion:
+            lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
+
         for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
             alpha = alphas[time]
             alpha_next = alphas[time_next]
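The added `maybe(self.normalize_img)(lowres_cond_img)` leans on the repo's small functional helpers: `maybe(fn)` wraps `fn` so a `None` input passes through untouched, which lets the first (unconditioned) unet share this line even though it has no low-res image. A sketch of the idiom as I understand it (helper bodies reconstructed from memory, so treat them as an approximation of the repo's `exists`/`maybe`):

```python
from functools import wraps

def exists(val):
    return val is not None

def maybe(fn):
    # lift fn so it silently skips None inputs
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

normalize_img = lambda img: img * 2 - 1   # stand-in for self.normalize_img ([0, 1] -> [-1, 1])

print(maybe(normalize_img)(None))         # None: first unet has no low-res conditioning image
print(maybe(normalize_img)(0.5))          # 0.0: a conditioning image gets normalized as usual
```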
@@ -2447,7 +2507,7 @@ class Decoder(nn.Module):
             shape = (batch_size, channel, image_size, image_size)
 
             if unet.lowres_cond:
-                lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
+                lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = self.lowres_downsample_mode_nearest)
 
             is_latent_diffusion = isinstance(vae, VQGanVAE)
             image_size = vae.get_encoded_fmap_size(image_size)
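During sampling, the low-res conditioning image now goes through `resize_image_to(..., clamp_range = self.input_image_range, nearest = self.lowres_downsample_mode_nearest)` rather than the full `to_lowres_cond` pipeline, so no blur is applied at inference. A simplified stand-in for what that resize does, built only on `F.interpolate` (the real helper also supports the resize-right package for the non-nearest path, so this is an approximation):

```python
import torch
import torch.nn.functional as F

def resize_image_to(image, target_image_size, clamp_range = None, nearest = False):
    if image.shape[-1] == target_image_size:
        return image
    mode = 'nearest' if nearest else 'bilinear'
    out = F.interpolate(image, size = target_image_size, mode = mode)
    if clamp_range is not None:
        out = out.clamp(*clamp_range)   # e.g. (0., 1.) to keep the conditioning image in range
    return out

img = torch.rand(1, 3, 256, 256)
print(resize_image_to(img, 64, clamp_range = (0., 1.), nearest = True).shape)  # torch.Size([1, 3, 64, 64])
```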