prepare non-causal attention, for use in the unet in the decoder

Phil Wang
2022-04-13 12:04:09 -07:00
parent c9377efc93
commit e5e415297c


@@ -116,12 +116,14 @@ class Attention(nn.Module):
         dim,
         dim_head = 64,
         heads = 8,
-        dropout = 0.
+        dropout = 0.,
+        causal = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
         inner_dim = dim_head * heads
 
+        self.causal = causal
         self.norm = RMSNorm(dim)
         self.dropout = nn.Dropout(dropout)
 
@@ -154,8 +156,9 @@ class Attention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)
 
-        causal_mask = torch.ones((n, n), dtype = torch.bool, device = device).triu(1)
-        sim = sim.masked_fill(causal_mask, max_neg_value)
+        if self.causal:
+            causal_mask = torch.ones((n, n), dtype = torch.bool, device = device).triu(1)
+            sim = sim.masked_fill(causal_mask, max_neg_value)
 
         sim = sim - sim.amax(dim = -1, keepdim = True)
         attn = sim.softmax(dim = -1)
@@ -165,7 +168,7 @@ class Attention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
-class Transformer(nn.Module):
+class CausalTransformer(nn.Module):
     def __init__(
         self,
         *,
@@ -184,7 +187,7 @@ class Transformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
             ]))
 
@@ -211,7 +214,7 @@ class DiffusionPriorNetwork(nn.Module):
         super().__init__()
         self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
-        self.causal_transformer = Transformer(**kwargs)
+        self.causal_transformer = CausalTransformer(**kwargs)
 
     def forward_with_cond_scale(
         self,
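For reference, a hypothetical usage sketch, assuming both classes are importable from the package (the import path below is a guess, not something this diff shows): the decoder's unet can now build Attention with the default causal = False for full bidirectional attention, while the diffusion prior keeps autoregressive behaviour through CausalTransformer, whose attention layers receive causal = True internally. The constructor arguments shown are only those visible in the hunks above.

from dalle2_pytorch.dalle2_pytorch import Attention, CausalTransformer  # import path is an assumption

# non-causal attention for use in the decoder's unet (causal defaults to False)
unet_attn = Attention(dim = 512, dim_head = 64, heads = 8, dropout = 0.)

# the prior's transformer stays causal; it hard-codes causal = True when building its Attention layers
prior_transformer = CausalTransformer(dim = 512, depth = 6, dim_head = 64, heads = 8)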