mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
debugging with Aidan
This commit is contained in:
@@ -1127,11 +1127,12 @@ class SinusoidalPosEmb(nn.Module):
|
|||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
dtype, device = x.dtype, x.device
|
||||||
half_dim = self.dim // 2
|
half_dim = self.dim // 2
|
||||||
emb = math.log(10000) / (half_dim - 1)
|
emb = math.log(10000) / (half_dim - 1)
|
||||||
emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
|
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
|
||||||
emb = rearrange(x.type_as(emb), 'i -> i 1') * rearrange(emb, 'j -> 1 j')
|
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
|
||||||
return torch.cat((emb.sin(), emb.cos()), dim = -1)
|
return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
|
||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -1626,6 +1627,7 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
# time conditioning
|
# time conditioning
|
||||||
|
|
||||||
|
time = time.type_as(x)
|
||||||
time_hiddens = self.to_time_hiddens(time)
|
time_hiddens = self.to_time_hiddens(time)
|
||||||
|
|
||||||
time_tokens = self.to_time_tokens(time_hiddens)
|
time_tokens = self.to_time_tokens(time_hiddens)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.16.8'
|
__version__ = '0.16.9'
|
||||||
|
|||||||
Reference in New Issue
Block a user