mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
use cross attention for conditioning unet based on image embedding tokens (which opens up the door on conditioning on text encodings as well
This commit is contained in:
@@ -101,7 +101,7 @@ clip = CLIP(
|
|||||||
unet = Unet(
|
unet = Unet(
|
||||||
dim = 128,
|
dim = 128,
|
||||||
image_embed_dim = 512,
|
image_embed_dim = 512,
|
||||||
time_dim = 128,
|
cond_dim = 128,
|
||||||
channels = 3,
|
channels = 3,
|
||||||
dim_mults=(1, 2, 4, 8)
|
dim_mults=(1, 2, 4, 8)
|
||||||
).cuda()
|
).cuda()
|
||||||
@@ -264,7 +264,7 @@ loss.backward()
|
|||||||
unet = Unet(
|
unet = Unet(
|
||||||
dim = 128,
|
dim = 128,
|
||||||
image_embed_dim = 512,
|
image_embed_dim = 512,
|
||||||
time_dim = 128,
|
cond_dim = 128,
|
||||||
channels = 3,
|
channels = 3,
|
||||||
dim_mults=(1, 2, 4, 8)
|
dim_mults=(1, 2, 4, 8)
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|||||||
@@ -118,14 +118,13 @@ class ChanRMSNorm(RMSNorm):
|
|||||||
inv_norm = torch.rsqrt(squared_sum + self.eps)
|
inv_norm = torch.rsqrt(squared_sum + self.eps)
|
||||||
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
|
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
|
||||||
|
|
||||||
class PreNormResidual(nn.Module):
|
class Residual(nn.Module):
|
||||||
def __init__(self, dim, fn):
|
def __init__(self, fn):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
self.norm = RMSNorm(dim)
|
|
||||||
|
|
||||||
def forward(self, x, **kwargs):
|
def forward(self, x, **kwargs):
|
||||||
return self.fn(self.norm(x), **kwargs) + x
|
return self.fn(x, **kwargs) + x
|
||||||
|
|
||||||
# mlp
|
# mlp
|
||||||
|
|
||||||
@@ -240,6 +239,7 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
sim = sim - sim.amax(dim = -1, keepdim = True)
|
sim = sim - sim.amax(dim = -1, keepdim = True)
|
||||||
attn = sim.softmax(dim = -1)
|
attn = sim.softmax(dim = -1)
|
||||||
|
attn = self.dropout(attn)
|
||||||
|
|
||||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||||
|
|
||||||
@@ -571,6 +571,17 @@ def Upsample(dim):
|
|||||||
def Downsample(dim):
|
def Downsample(dim):
|
||||||
return nn.Conv2d(dim, dim, 4, 2, 1)
|
return nn.Conv2d(dim, dim, 4, 2, 1)
|
||||||
|
|
||||||
|
class Blur(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
filt = torch.Tensor([1, 2, 1])
|
||||||
|
self.register_buffer('filt', filt)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
filt = self.filt
|
||||||
|
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1')
|
||||||
|
return filter2d(x, filt, normalized = True)
|
||||||
|
|
||||||
class SinusoidalPosEmb(nn.Module):
|
class SinusoidalPosEmb(nn.Module):
|
||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -598,10 +609,17 @@ class ConvNextBlock(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
need_projection = dim != dim_out
|
need_projection = dim != dim_out
|
||||||
|
|
||||||
self.mlp = nn.Sequential(
|
self.cross_attn = None
|
||||||
nn.GELU(),
|
|
||||||
nn.Linear(cond_dim, dim)
|
if exists(cond_dim):
|
||||||
) if exists(cond_dim) else None
|
self.cross_attn = EinopsToAndFrom(
|
||||||
|
'b c h w',
|
||||||
|
'b (h w) c',
|
||||||
|
CrossAttention(
|
||||||
|
dim = dim,
|
||||||
|
context_dim = cond_dim
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
|
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
|
||||||
|
|
||||||
@@ -618,21 +636,82 @@ class ConvNextBlock(nn.Module):
|
|||||||
def forward(self, x, cond = None):
|
def forward(self, x, cond = None):
|
||||||
h = self.ds_conv(x)
|
h = self.ds_conv(x)
|
||||||
|
|
||||||
if exists(self.mlp):
|
if exists(self.cross_attn):
|
||||||
assert exists(cond)
|
assert exists(cond)
|
||||||
condition = self.mlp(cond)
|
h = self.cross_attn(h, context = cond) + h
|
||||||
h = h + rearrange(condition, 'b c -> b c 1 1')
|
|
||||||
|
|
||||||
h = self.net(h)
|
h = self.net(h)
|
||||||
|
|
||||||
return h + self.res_conv(x)
|
return h + self.res_conv(x)
|
||||||
|
|
||||||
|
class CrossAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
*,
|
||||||
|
context_dim = None,
|
||||||
|
dim_head = 64,
|
||||||
|
heads = 8,
|
||||||
|
dropout = 0.,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
self.heads = heads
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
|
||||||
|
context_dim = default(context_dim, dim)
|
||||||
|
|
||||||
|
self.norm = RMSNorm(dim)
|
||||||
|
self.norm_context = RMSNorm(context_dim)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||||
|
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||||
|
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
||||||
|
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||||
|
|
||||||
|
def forward(self, x, context, mask = None):
|
||||||
|
b, n, device = *x.shape[:2], x.device
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
context = self.norm_context(context)
|
||||||
|
|
||||||
|
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
|
||||||
|
|
||||||
|
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
|
||||||
|
|
||||||
|
# add null key / value for classifier free guidance in prior net
|
||||||
|
|
||||||
|
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
|
||||||
|
|
||||||
|
k = torch.cat((nk, k), dim = -2)
|
||||||
|
v = torch.cat((nv, v), dim = -2)
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
|
||||||
|
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||||
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
mask = F.pad(mask, (1, 0), value = True)
|
||||||
|
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||||
|
sim = sim.masked_fill(~mask, max_neg_value)
|
||||||
|
|
||||||
|
sim = sim - sim.amax(dim = -1, keepdim = True)
|
||||||
|
attn = sim.softmax(dim = -1)
|
||||||
|
|
||||||
|
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||||
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
class Unet(nn.Module):
|
class Unet(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim,
|
dim,
|
||||||
*,
|
*,
|
||||||
image_embed_dim,
|
image_embed_dim,
|
||||||
time_dim = None,
|
cond_dim = None,
|
||||||
|
num_image_tokens = 4,
|
||||||
out_dim = None,
|
out_dim = None,
|
||||||
dim_mults=(1, 2, 4, 8),
|
dim_mults=(1, 2, 4, 8),
|
||||||
channels = 3,
|
channels = 3,
|
||||||
@@ -643,18 +722,28 @@ class Unet(nn.Module):
|
|||||||
dims = [channels, *map(lambda m: dim * m, dim_mults)]
|
dims = [channels, *map(lambda m: dim * m, dim_mults)]
|
||||||
in_out = list(zip(dims[:-1], dims[1:]))
|
in_out = list(zip(dims[:-1], dims[1:]))
|
||||||
|
|
||||||
time_dim = default(time_dim, dim)
|
# time and image embeddings
|
||||||
|
|
||||||
|
cond_dim = default(cond_dim, dim)
|
||||||
|
|
||||||
self.time_mlp = nn.Sequential(
|
self.time_mlp = nn.Sequential(
|
||||||
SinusoidalPosEmb(dim),
|
SinusoidalPosEmb(dim),
|
||||||
nn.Linear(dim, dim * 4),
|
nn.Linear(dim, dim * 4),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(dim * 4, dim)
|
nn.Linear(dim * 4, cond_dim),
|
||||||
|
Rearrange('b d -> b 1 d')
|
||||||
)
|
)
|
||||||
|
|
||||||
self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
|
self.image_to_cond = nn.Sequential(
|
||||||
|
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||||
|
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||||
|
) if image_embed_dim != cond_dim else nn.Identity()
|
||||||
|
|
||||||
cond_dim = time_dim + image_embed_dim
|
# for classifier free guidance
|
||||||
|
|
||||||
|
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||||
|
|
||||||
|
# layers
|
||||||
|
|
||||||
self.downs = nn.ModuleList([])
|
self.downs = nn.ModuleList([])
|
||||||
self.ups = nn.ModuleList([])
|
self.ups = nn.ModuleList([])
|
||||||
@@ -664,7 +753,7 @@ class Unet(nn.Module):
|
|||||||
is_last = ind >= (num_resolutions - 1)
|
is_last = ind >= (num_resolutions - 1)
|
||||||
|
|
||||||
self.downs.append(nn.ModuleList([
|
self.downs.append(nn.ModuleList([
|
||||||
ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
|
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
|
||||||
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
|
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
|
||||||
Downsample(dim_out) if not is_last else nn.Identity()
|
Downsample(dim_out) if not is_last else nn.Identity()
|
||||||
]))
|
]))
|
||||||
@@ -672,7 +761,7 @@ class Unet(nn.Module):
|
|||||||
mid_dim = dims[-1]
|
mid_dim = dims[-1]
|
||||||
|
|
||||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
|
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
|
||||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||||
|
|
||||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||||
@@ -714,52 +803,43 @@ class Unet(nn.Module):
|
|||||||
cond_drop_prob = 0.
|
cond_drop_prob = 0.
|
||||||
):
|
):
|
||||||
batch_size, device = x.shape[0], x.device
|
batch_size, device = x.shape[0], x.device
|
||||||
t = self.time_mlp(time)
|
time_tokens = self.time_mlp(time)
|
||||||
|
|
||||||
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
|
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
|
||||||
|
|
||||||
# mask out image embedding depending on condition dropout
|
# mask out image embedding depending on condition dropout
|
||||||
# for classifier free guidance
|
# for classifier free guidance
|
||||||
|
|
||||||
image_embed = torch.where(
|
image_tokens = self.image_to_cond(image_embed)
|
||||||
rearrange(cond_prob_mask, 'b -> b 1'),
|
|
||||||
image_embed,
|
image_tokens = torch.where(
|
||||||
rearrange(self.null_image_embed, 'd -> 1 d')
|
rearrange(cond_prob_mask, 'b -> b 1 1'),
|
||||||
|
image_tokens,
|
||||||
|
self.null_image_embed
|
||||||
)
|
)
|
||||||
|
|
||||||
t = torch.cat((t, image_embed), dim = -1)
|
c = torch.cat((time_tokens, image_tokens), dim = -2) # c for condition
|
||||||
|
|
||||||
hiddens = []
|
hiddens = []
|
||||||
|
|
||||||
for convnext, convnext2, downsample in self.downs:
|
for convnext, convnext2, downsample in self.downs:
|
||||||
x = convnext(x, t)
|
x = convnext(x, c)
|
||||||
x = convnext2(x, t)
|
x = convnext2(x, c)
|
||||||
hiddens.append(x)
|
hiddens.append(x)
|
||||||
x = downsample(x)
|
x = downsample(x)
|
||||||
|
|
||||||
x = self.mid_block1(x, t)
|
x = self.mid_block1(x, c)
|
||||||
x = self.mid_attn(x)
|
x = self.mid_attn(x)
|
||||||
x = self.mid_block2(x, t)
|
x = self.mid_block2(x, c)
|
||||||
|
|
||||||
for convnext, convnext2, upsample in self.ups:
|
for convnext, convnext2, upsample in self.ups:
|
||||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||||
x = convnext(x, t)
|
x = convnext(x, c)
|
||||||
x = convnext2(x, t)
|
x = convnext2(x, c)
|
||||||
x = upsample(x)
|
x = upsample(x)
|
||||||
|
|
||||||
return self.final_conv(x)
|
return self.final_conv(x)
|
||||||
|
|
||||||
class Blur(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
filt = torch.Tensor([1, 2, 1])
|
|
||||||
self.register_buffer('filt', filt)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
filt = self.filt
|
|
||||||
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1')
|
|
||||||
return filter2d(x, filt, normalized = True)
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user