debugging with Aidan

This commit is contained in:
Phil Wang
2022-07-05 18:22:43 -07:00
parent 1e4bb2bafb
commit f33453df9f
2 changed files with 6 additions and 4 deletions

View File

@@ -1127,11 +1127,12 @@ class SinusoidalPosEmb(nn.Module):
self.dim = dim self.dim = dim
def forward(self, x): def forward(self, x):
dtype, device = x.dtype, x.device
half_dim = self.dim // 2 half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1) emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb) emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
emb = rearrange(x.type_as(emb), 'i -> i 1') * rearrange(emb, 'j -> 1 j') emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1) return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
class Block(nn.Module): class Block(nn.Module):
def __init__( def __init__(
@@ -1626,6 +1627,7 @@ class Unet(nn.Module):
# time conditioning # time conditioning
time = time.type_as(x)
time_hiddens = self.to_time_hiddens(time) time_hiddens = self.to_time_hiddens(time)
time_tokens = self.to_time_tokens(time_hiddens) time_tokens = self.to_time_tokens(time_hiddens)

View File

@@ -1 +1 @@
__version__ = '0.16.8' __version__ = '0.16.9'