use cross attention for conditioning unet based on image embedding tokens (which opens up the door on conditioning on text encodings as well

This commit is contained in:
Phil Wang
2022-04-14 10:10:04 -07:00
parent 95b018374a
commit 68e9883f59
3 changed files with 124 additions and 44 deletions

View File

@@ -118,14 +118,13 @@ class ChanRMSNorm(RMSNorm):
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.norm = RMSNorm(dim)
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs) + x
return self.fn(x, **kwargs) + x
# mlp
@@ -240,6 +239,7 @@ class Attention(nn.Module):
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
out = einsum('b h i j, b j d -> b h i d', attn, v)
@@ -571,6 +571,17 @@ def Upsample(dim):
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
class Blur(nn.Module):
def __init__(self):
super().__init__()
filt = torch.Tensor([1, 2, 1])
self.register_buffer('filt', filt)
def forward(self, x):
filt = self.filt
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1')
return filter2d(x, filt, normalized = True)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -598,10 +609,17 @@ class ConvNextBlock(nn.Module):
super().__init__()
need_projection = dim != dim_out
self.mlp = nn.Sequential(
nn.GELU(),
nn.Linear(cond_dim, dim)
) if exists(cond_dim) else None
self.cross_attn = None
if exists(cond_dim):
self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim,
context_dim = cond_dim
)
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
@@ -618,21 +636,82 @@ class ConvNextBlock(nn.Module):
def forward(self, x, cond = None):
h = self.ds_conv(x)
if exists(self.mlp):
if exists(self.cross_attn):
assert exists(cond)
condition = self.mlp(cond)
h = h + rearrange(condition, 'b c -> b c 1 1')
h = self.cross_attn(h, context = cond) + h
h = self.net(h)
return h + self.res_conv(x)
class CrossAttention(nn.Module):
def __init__(
self,
dim,
*,
context_dim = None,
dim_head = 64,
heads = 8,
dropout = 0.,
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
context_dim = default(context_dim, dim)
self.norm = RMSNorm(dim)
self.norm_context = RMSNorm(context_dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
context = self.norm_context(context)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
# add null key / value for classifier free guidance in prior net
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Unet(nn.Module):
def __init__(
self,
dim,
*,
image_embed_dim,
time_dim = None,
cond_dim = None,
num_image_tokens = 4,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
@@ -643,18 +722,28 @@ class Unet(nn.Module):
dims = [channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
time_dim = default(time_dim, dim)
# time and image embeddings
cond_dim = default(cond_dim, dim)
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
nn.Linear(dim * 4, cond_dim),
Rearrange('b d -> b 1 d')
)
self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
cond_dim = time_dim + image_embed_dim
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
@@ -664,7 +753,7 @@ class Unet(nn.Module):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
@@ -672,7 +761,7 @@ class Unet(nn.Module):
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -714,52 +803,43 @@ class Unet(nn.Module):
cond_drop_prob = 0.
):
batch_size, device = x.shape[0], x.device
t = self.time_mlp(time)
time_tokens = self.time_mlp(time)
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_embed = torch.where(
rearrange(cond_prob_mask, 'b -> b 1'),
image_embed,
rearrange(self.null_image_embed, 'd -> 1 d')
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
rearrange(cond_prob_mask, 'b -> b 1 1'),
image_tokens,
self.null_image_embed
)
t = torch.cat((t, image_embed), dim = -1)
c = torch.cat((time_tokens, image_tokens), dim = -2) # c for condition
hiddens = []
for convnext, convnext2, downsample in self.downs:
x = convnext(x, t)
x = convnext2(x, t)
x = convnext(x, c)
x = convnext2(x, c)
hiddens.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_block1(x, c)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
x = self.mid_block2(x, c)
for convnext, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, t)
x = convnext2(x, t)
x = convnext(x, c)
x = convnext2(x, c)
x = upsample(x)
return self.final_conv(x)
class Blur(nn.Module):
def __init__(self):
super().__init__()
filt = torch.Tensor([1, 2, 1])
self.register_buffer('filt', filt)
def forward(self, x):
filt = self.filt
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1')
return filter2d(x, filt, normalized = True)
class Decoder(nn.Module):
def __init__(
self,