add cosine sim for self attention as well, as a setting

This commit is contained in:
Phil Wang
2022-07-29 12:48:20 -07:00
parent 2d67d5821e
commit d167378401
2 changed files with 14 additions and 8 deletions

View File

@@ -701,11 +701,12 @@ class Attention(nn.Module):
dropout = 0., dropout = 0.,
causal = False, causal = False,
rotary_emb = None, rotary_emb = None,
pb_relax_alpha = 128 cosine_sim = True,
cosine_sim_scale = 16
): ):
super().__init__() super().__init__()
self.pb_relax_alpha = pb_relax_alpha self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1) self.cosine_sim = cosine_sim
self.heads = heads self.heads = heads
inner_dim = dim_head * heads inner_dim = dim_head * heads
@@ -745,6 +746,13 @@ class Attention(nn.Module):
k = torch.cat((nk, k), dim = -2) k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2) v = torch.cat((nv, v), dim = -2)
# whether to use cosine sim
if self.cosine_sim:
q, k = map(l2norm, (q, k))
q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
# calculate query / key similarities # calculate query / key similarities
sim = einsum('b h i d, b j d -> b h i j', q, k) sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -770,9 +778,6 @@ class Attention(nn.Module):
# attention # attention
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1) attn = sim.softmax(dim = -1)
attn = self.dropout(attn) attn = self.dropout(attn)
@@ -1604,6 +1609,7 @@ class Unet(nn.Module):
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
sparse_attn = False, sparse_attn = False,
cosine_sim_cross_attn = False, cosine_sim_cross_attn = False,
cosine_sim_self_attn = False,
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False, cond_on_text_encodings = False,
max_text_len = 256, max_text_len = 256,
@@ -1724,7 +1730,7 @@ class Unet(nn.Module):
# attention related params # attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head) attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim = cosine_sim_self_attn)
self_attn = cast_tuple(self_attn, num_stages) self_attn = cast_tuple(self_attn, num_stages)

View File

@@ -1 +1 @@
__version__ = '1.4.2' __version__ = '1.4.3'