From 68e9883f5965e170640b6bbed07c63df7bb662b3 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 14 Apr 2022 10:10:04 -0700
Subject: [PATCH] use cross attention to condition the unet on image embedding
 tokens (which opens the door to conditioning on text encodings as well)

---
 README.md                        |   4 +-
 dalle2_pytorch/dalle2_pytorch.py | 162 +++++++++++++++++++++++--------
 setup.py                         |   2 +-
 3 files changed, 124 insertions(+), 44 deletions(-)

diff --git a/README.md b/README.md
index 8626de3..59925e8 100644
--- a/README.md
+++ b/README.md
@@ -101,7 +101,7 @@ clip = CLIP(
 unet = Unet(
     dim = 128,
     image_embed_dim = 512,
-    time_dim = 128,
+    cond_dim = 128,
     channels = 3,
     dim_mults=(1, 2, 4, 8)
 ).cuda()
@@ -264,7 +264,7 @@ loss.backward()
 unet = Unet(
     dim = 128,
     image_embed_dim = 512,
-    time_dim = 128,
+    cond_dim = 128,
     channels = 3,
     dim_mults=(1, 2, 4, 8)
 ).cuda()
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index fb67a2d..b2c4fbd 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -118,14 +118,13 @@ class ChanRMSNorm(RMSNorm):
         inv_norm = torch.rsqrt(squared_sum + self.eps)
         return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
 
-class PreNormResidual(nn.Module):
-    def __init__(self, dim, fn):
+class Residual(nn.Module):
+    def __init__(self, fn):
         super().__init__()
         self.fn = fn
-        self.norm = RMSNorm(dim)
 
     def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs) + x
+        return self.fn(x, **kwargs) + x
 
 # mlp
 
@@ -240,6 +239,7 @@ class Attention(nn.Module):
         sim = sim - sim.amax(dim = -1, keepdim = True)
 
         attn = sim.softmax(dim = -1)
+        attn = self.dropout(attn)
 
         out = einsum('b h i j, b j d -> b h i d', attn, v)
 
@@ -571,6 +571,17 @@ def Upsample(dim):
 def Downsample(dim):
     return nn.Conv2d(dim, dim, 4, 2, 1)
 
+class Blur(nn.Module):
+    def __init__(self):
+        super().__init__()
+        filt = torch.Tensor([1, 2, 1])
+        self.register_buffer('filt', filt)
+
+    def forward(self, x):
+        filt = self.filt
+        filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1')
+        return filter2d(x, filt, normalized = True)
+
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -598,10 +609,17 @@ class ConvNextBlock(nn.Module):
         super().__init__()
         need_projection = dim != dim_out
 
-        self.mlp = nn.Sequential(
-            nn.GELU(),
-            nn.Linear(cond_dim, dim)
-        ) if exists(cond_dim) else None
+        self.cross_attn = None
+
+        if exists(cond_dim):
+            self.cross_attn = EinopsToAndFrom(
+                'b c h w',
+                'b (h w) c',
+                CrossAttention(
+                    dim = dim,
+                    context_dim = cond_dim
+                )
+            )
 
         self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
 
@@ -618,21 +636,82 @@ class ConvNextBlock(nn.Module):
     def forward(self, x, cond = None):
         h = self.ds_conv(x)
 
-        if exists(self.mlp):
+        if exists(self.cross_attn):
             assert exists(cond)
-            condition = self.mlp(cond)
-            h = h + rearrange(condition, 'b c -> b c 1 1')
+            h = self.cross_attn(h, context = cond) + h
 
         h = self.net(h)
         return h + self.res_conv(x)
 
+class CrossAttention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        context_dim = None,
+        dim_head = 64,
+        heads = 8,
+        dropout = 0.,
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        context_dim = default(context_dim, dim)
+
+        self.norm = RMSNorm(dim)
+        self.norm_context = RMSNorm(context_dim)
+        self.dropout = nn.Dropout(dropout)
+
+        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+    def forward(self, x, context, mask = None):
+        b, n, device = *x.shape[:2], x.device
+
+        x = self.norm(x)
+        context = self.norm_context(context)
+
+        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
+
+        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
+
+        # add null key / value for classifier free guidance in prior net
+
+        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
+
+        k = torch.cat((nk, k), dim = -2)
+        v = torch.cat((nv, v), dim = -2)
+
+        q = q * self.scale
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k)
+        max_neg_value = -torch.finfo(sim.dtype).max
+
+        if exists(mask):
+            mask = F.pad(mask, (1, 0), value = True)
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+            sim = sim.masked_fill(~mask, max_neg_value)
+
+        sim = sim - sim.amax(dim = -1, keepdim = True)
+        attn = sim.softmax(dim = -1)
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
 class Unet(nn.Module):
     def __init__(
         self,
         dim,
         *,
         image_embed_dim,
-        time_dim = None,
+        cond_dim = None,
+        num_image_tokens = 4,
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
@@ -643,18 +722,28 @@ class Unet(nn.Module):
         dims = [channels, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
-        time_dim = default(time_dim, dim)
+        # time and image embeddings
+
+        cond_dim = default(cond_dim, dim)
 
         self.time_mlp = nn.Sequential(
             SinusoidalPosEmb(dim),
             nn.Linear(dim, dim * 4),
             nn.GELU(),
-            nn.Linear(dim * 4, dim)
+            nn.Linear(dim * 4, cond_dim),
+            Rearrange('b d -> b 1 d')
         )
 
-        self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
+        self.image_to_cond = nn.Sequential(
+            nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
+            Rearrange('b (n d) -> b n d', n = num_image_tokens)
+        ) if image_embed_dim != cond_dim else nn.Identity()
 
-        cond_dim = time_dim + image_embed_dim
+        # for classifier free guidance
+
+        self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
+
+        # layers
 
         self.downs = nn.ModuleList([])
         self.ups = nn.ModuleList([])
@@ -664,7 +753,7 @@
             is_last = ind >= (num_resolutions - 1)
 
             self.downs.append(nn.ModuleList([
-                ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
+                ConvNextBlock(dim_in, dim_out, norm = ind != 0),
                 ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
@@ -672,7 +761,7 @@
         mid_dim = dims[-1]
 
         self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
-        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
+        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
         self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -714,52 +803,43 @@
         cond_drop_prob = 0.
     ):
         batch_size, device = x.shape[0], x.device
-        t = self.time_mlp(time)
+        time_tokens = self.time_mlp(time)
 
         cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
 
         # mask out image embedding depending on condition dropout
         # for classifier free guidance
 
-        image_embed = torch.where(
-            rearrange(cond_prob_mask, 'b -> b 1'),
-            image_embed,
-            rearrange(self.null_image_embed, 'd -> 1 d')
+        image_tokens = self.image_to_cond(image_embed)
+
+        image_tokens = torch.where(
+            rearrange(cond_prob_mask, 'b -> b 1 1'),
+            image_tokens,
+            self.null_image_embed
         )
 
-        t = torch.cat((t, image_embed), dim = -1)
+        c = torch.cat((time_tokens, image_tokens), dim = -2) # c for condition
 
         hiddens = []
 
         for convnext, convnext2, downsample in self.downs:
-            x = convnext(x, t)
-            x = convnext2(x, t)
+            x = convnext(x, c)
+            x = convnext2(x, c)
             hiddens.append(x)
             x = downsample(x)
 
-        x = self.mid_block1(x, t)
+        x = self.mid_block1(x, c)
         x = self.mid_attn(x)
-        x = self.mid_block2(x, t)
+        x = self.mid_block2(x, c)
 
         for convnext, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
-            x = convnext(x, t)
-            x = convnext2(x, t)
+            x = convnext(x, c)
+            x = convnext2(x, c)
             x = upsample(x)
 
         return self.final_conv(x)
 
-class Blur(nn.Module):
-    def __init__(self):
-        super().__init__()
-        filt = torch.Tensor([1, 2, 1])
-        self.register_buffer('filt', filt)
-
-    def forward(self, x):
-        filt = self.filt
-        filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1')
-        return filter2d(x, filt, normalized = True)
-
 class Decoder(nn.Module):
     def __init__(
         self,
diff --git a/setup.py b/setup.py
index 99c68f8..ae2d847 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.10',
+  version = '0.0.11',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
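
A minimal sketch of the conditioning flow this patch introduces, using only torch and einops (dimensions mirror the README example above; batch = 2 and the random tensors are stand-ins, not library code): the image embedding is projected to num_image_tokens condition tokens, the time embedding contributes one more token, and the concatenated sequence is what every ConvNextBlock cross-attends to.

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    batch, image_embed_dim, cond_dim, num_image_tokens = 2, 512, 128, 4

    # same projection as Unet.image_to_cond in this patch
    image_to_cond = nn.Sequential(
        nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
        Rearrange('b (n d) -> b n d', n = num_image_tokens)
    )

    image_embed  = torch.randn(batch, image_embed_dim)
    image_tokens = image_to_cond(image_embed)       # (2, 4, 128)
    time_tokens  = torch.randn(batch, 1, cond_dim)  # stand-in for time_mlp output
    c = torch.cat((time_tokens, image_tokens), dim = -2)
    print(c.shape)                                  # torch.Size([2, 5, 128])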
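
And a quick shape check of the new CrossAttention block (assumes this patch has been applied so the class is importable). The learned null key/value is what keeps classifier free guidance well behaved here: even a fully masked-out context leaves each pixel query one token to attend to, so the softmax stays well defined.

    import torch
    from dalle2_pytorch.dalle2_pytorch import CrossAttention

    attn = CrossAttention(dim = 128, context_dim = 128)

    x       = torch.randn(2, 64, 128)   # (batch, flattened h * w pixels, dim)
    context = torch.randn(2, 5, 128)    # 1 time token + 4 image tokens
    mask    = torch.zeros(2, 5).bool()  # drop every real context token

    out = attn(x, context, mask = mask) # attends only to the null key / value
    print(out.shape)                    # torch.Size([2, 64, 128])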
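
For reference, the (relocated) Blur module builds its kernel as an outer product of the binomial filter [1, 2, 1]; a tiny standalone check, in plain torch and einops, of the 3x3 kernel that filter2d then normalizes and convolves with:

    import torch
    from einops import rearrange

    filt = torch.Tensor([1, 2, 1])
    kernel = rearrange(filt, 'j -> 1 j') * rearrange(filt, 'i -> i 1')
    print(kernel)                 # [[1., 2., 1.], [2., 4., 2.], [1., 2., 1.]]
    print(kernel / kernel.sum())  # what filter2d(..., normalized = True) effectively applies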