diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 2c17908..6491e2c 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -100,6 +100,9 @@ def eval_decorator(fn):
         return out
     return inner

+def is_float_dtype(dtype):
+    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
+
 def is_list_str(x):
     if not isinstance(x, (list, tuple)):
         return False
@@ -968,6 +971,8 @@ class DiffusionPriorNetwork(nn.Module):
             Rearrange('b (n d) -> b n d', n = num_text_embeds)
         )

+        self.continuous_embedded_time = not exists(num_timesteps)
+
         self.to_time_embeds = nn.Sequential(
             nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
             Rearrange('b (n d) -> b n d', n = num_time_embeds)
@@ -1095,6 +1100,9 @@ class DiffusionPriorNetwork(nn.Module):
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right

+        if self.continuous_embedded_time:
+            diffusion_timesteps = diffusion_timesteps.type(dtype)
+
         time_embed = self.to_time_embeds(diffusion_timesteps)

         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -1538,6 +1546,8 @@ class SinusoidalPosEmb(nn.Module):

     def forward(self, x):
         dtype, device = x.dtype, x.device
+        assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
+
         half_dim = self.dim // 2
         emb = math.log(10000) / (half_dim - 1)
         emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 8047ba3..192e7e0 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.10.7'
+__version__ = '1.10.8'