diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 6a16a14..694483d 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1127,11 +1127,12 @@ class SinusoidalPosEmb(nn.Module): self.dim = dim def forward(self, x): + dtype, device = x.dtype, x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb) - emb = rearrange(x.type_as(emb), 'i -> i 1') * rearrange(emb, 'j -> 1 j') - return torch.cat((emb.sin(), emb.cos()), dim = -1) + emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb) + emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j') + return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype) class Block(nn.Module): def __init__( @@ -1626,6 +1627,7 @@ class Unet(nn.Module): # time conditioning + time = time.type_as(x) time_hiddens = self.to_time_hiddens(time) time_tokens = self.to_time_tokens(time_hiddens) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 731cfdb..bc68d1e 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.8' +__version__ = '0.16.9'