mirror of https://github.com/lucidrains/DALLE2-pytorch.git
prepare non-causal attention, for use in the unet in the decoder
@@ -116,12 +116,14 @@ class Attention(nn.Module):
         dim,
         dim_head = 64,
         heads = 8,
-        dropout = 0.
+        dropout = 0.,
+        causal = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
         inner_dim = dim_head * heads

+        self.causal = causal
         self.norm = RMSNorm(dim)
         self.dropout = nn.Dropout(dropout)

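For orientation, a quick sanity sketch of the new flag: causality is now opt-in at construction time rather than hard-coded, and CausalTransformer below opts back in explicitly. The import path is assumed, and dim is assumed to be the only required constructor argument, as the visible part of the signature suggests; none of this is part of the commit itself.

# Sanity sketch of the new constructor flag (assumed import path, not part of this commit)
from dalle2_pytorch.dalle2_pytorch import Attention

attn_for_unet  = Attention(dim = 512)                 # new default: causal = False (non-causal)
attn_for_prior = Attention(dim = 512, causal = True)  # what CausalTransformer now passes explicitly

assert attn_for_unet.causal is False
assert attn_for_prior.causal is True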
@@ -154,6 +156,7 @@ class Attention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)

-        causal_mask = torch.ones((n, n), dtype = torch.bool, device = device).triu(1)
-        sim = sim.masked_fill(causal_mask, max_neg_value)
+        if self.causal:
+            causal_mask = torch.ones((n, n), dtype = torch.bool, device = device).triu(1)
+            sim = sim.masked_fill(causal_mask, max_neg_value)

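To make the added branch concrete, a small standalone example of what the mask does for n = 4 tokens; max_neg_value here stands in for the value of the same name used in the surrounding code.

# Worked example of the causal mask above (standalone, CPU, n = 4)
import torch

n = 4
sim = torch.zeros(1, 1, n, n)                                 # (batch, heads, i, j) attention logits
causal_mask = torch.ones((n, n), dtype = torch.bool).triu(1)  # True strictly above the diagonal (j > i)
max_neg_value = -torch.finfo(sim.dtype).max                   # most negative finite float for this dtype
sim = sim.masked_fill(causal_mask, max_neg_value)             # block attention to future positions
attn = sim.softmax(dim = -1)                                  # masked entries get ~0 probability
print(attn[0, 0])
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000, 0.0000],
#         [0.3333, 0.3333, 0.3333, 0.0000],
#         [0.2500, 0.2500, 0.2500, 0.2500]])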
@@ -165,7 +168,7 @@ class Attention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)

-class Transformer(nn.Module):
+class CausalTransformer(nn.Module):
     def __init__(
         self,
         *,
@@ -184,7 +187,7 @@ class Transformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
             ]))

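This hunk only changes how each layer pair is built. For context, a sketch of how an [Attention, FeedForward] list like this is typically consumed; CausalTransformer's forward is not part of the diff, so the residual wiring below is an assumption for illustration only.

# Illustrative consumption of a layer list like the one built above;
# the residual connections are assumed, not taken from this diff.
import torch
from torch import nn

def run_layers(layers: nn.ModuleList, x: torch.Tensor) -> torch.Tensor:
    for attn, ff in layers:   # each entry is [Attention, FeedForward]
        x = attn(x) + x       # residual around (here causal) self-attention
        x = ff(x) + x         # residual around the feedforward block
    return x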
@@ -211,7 +214,7 @@ class DiffusionPriorNetwork(nn.Module):
         super().__init__()
         self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
-        self.causal_transformer = Transformer(**kwargs)
+        self.causal_transformer = CausalTransformer(**kwargs)

     def forward_with_cond_scale(
         self,
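Taken together with the commit message: the prior keeps causal attention because CausalTransformer passes causal = True per layer, while the decoder's unet can reuse the same Attention class bidirectionally. A hypothetical sketch of that non-causal use follows; the import path, forward signature, and shapes are assumptions, not shown in this diff.

# Hypothetical non-causal use of the prepared Attention (assumed import path and shapes)
import torch
from dalle2_pytorch.dalle2_pytorch import Attention

attn = Attention(dim = 256, causal = False)  # bidirectional: every position may attend to every other
feats = torch.randn(2, 32 * 32, 256)         # e.g. a flattened 32x32 unet feature map: (batch, seq, dim)
out = attn(feats)                            # same shape as the input: (2, 1024, 256)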