cast long as float before deriving sinusoidal pos emb

This commit is contained in:
Phil Wang
2022-07-05 18:01:22 -07:00
parent ee75515c7d
commit 1e4bb2bafb
2 changed files with 2 additions and 2 deletions

View File

@@ -1130,7 +1130,7 @@ class SinusoidalPosEmb(nn.Module):
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
emb = rearrange(x.type_as(emb), 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1)
class Block(nn.Module):

View File

@@ -1 +1 @@
__version__ = '0.16.7'
__version__ = '0.16.8'