mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
cast long as float before deriving sinusoidal pos emb
This commit is contained in:
@@ -1130,7 +1130,7 @@ class SinusoidalPosEmb(nn.Module):
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
|
||||
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
|
||||
emb = rearrange(x.type_as(emb), 'i -> i 1') * rearrange(emb, 'j -> 1 j')
|
||||
return torch.cat((emb.sin(), emb.cos()), dim = -1)
|
||||
|
||||
class Block(nn.Module):
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.16.7'
|
||||
__version__ = '0.16.8'
|
||||
|
||||
Reference in New Issue
Block a user