Compare commits

...

3 Commits

Author SHA1 Message Date
Phil Wang
f33453df9f debugging with Aidan 2022-07-05 18:22:43 -07:00
Phil Wang
1e4bb2bafb cast long as float before deriving sinusoidal pos emb 2022-07-05 18:01:22 -07:00
Phil Wang
ee75515c7d remove forcing of softmax in f32, in case it is interfering with deepspeed 2022-07-05 16:53:58 -07:00
2 changed files with 7 additions and 5 deletions

View File

@@ -704,7 +704,7 @@ class Attention(nn.Module):
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate values
@@ -1127,11 +1127,12 @@ class SinusoidalPosEmb(nn.Module):
self.dim = dim
def forward(self, x):
dtype, device = x.dtype, x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1)
return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
class Block(nn.Module):
def __init__(
@@ -1272,7 +1273,7 @@ class CrossAttention(nn.Module):
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1626,6 +1627,7 @@ class Unet(nn.Module):
# time conditioning
time = time.type_as(x)
time_hiddens = self.to_time_hiddens(time)
time_tokens = self.to_time_tokens(time_hiddens)

View File

@@ -1 +1 @@
__version__ = '0.16.6'
__version__ = '0.16.9'