Compare commits

..

1 Commits
1.4.3 ... 1.4.1

2 changed files with 9 additions and 15 deletions

View File

@@ -580,7 +580,7 @@ class ChanLayerNorm(nn.Module):
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) * (var + eps).rsqrt() * self.g
+        return (x - mean) * (var + self.eps).rsqrt() * self.g
class Residual(nn.Module):
def __init__(self, fn):
@@ -701,12 +701,11 @@ class Attention(nn.Module):
         dropout = 0.,
         causal = False,
         rotary_emb = None,
-        cosine_sim = True,
-        cosine_sim_scale = 16
+        pb_relax_alpha = 128
     ):
         super().__init__()
-        self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
-        self.cosine_sim = cosine_sim
+        self.pb_relax_alpha = pb_relax_alpha
+        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
         self.heads = heads
         inner_dim = dim_head * heads
@@ -746,13 +745,6 @@ class Attention(nn.Module):
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
-        # whether to use cosine sim
-        if self.cosine_sim:
-            q, k = map(l2norm, (q, k))
-        q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
         # calculate query / key similarities
         sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -778,6 +770,9 @@ class Attention(nn.Module):
         # attention
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        sim = sim * self.pb_relax_alpha
         attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
@@ -1609,7 +1604,6 @@ class Unet(nn.Module):
         lowres_noise_cond = False,   # for conditioning on low resolution noising, based on Imagen
         sparse_attn = False,
         cosine_sim_cross_attn = False,
-        cosine_sim_self_attn = False,
         attend_at_middle = True,     # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
         cond_on_text_encodings = False,
         max_text_len = 256,
@@ -1730,7 +1724,7 @@ class Unet(nn.Module):
         # attention related params
-        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim = cosine_sim_self_attn)
+        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
         self_attn = cast_tuple(self_attn, num_stages)

View File

@@ -1 +1 @@
-__version__ = '1.4.3'
+__version__ = '1.4.1'