mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
fix a dtype conversion issue for the diffusion timesteps in the diffusion prior, thanks to @JiaHeng-DLUT
@@ -100,6 +100,9 @@ def eval_decorator(fn):
         return out
     return inner
 
+def is_float_dtype(dtype):
+    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
+
 def is_list_str(x):
     if not isinstance(x, (list, tuple)):
         return False
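For reference, a minimal standalone check of the new helper (the function body is repeated only to make the snippet self-contained; the print lines are illustrative, not from the repo):

import torch

def is_float_dtype(dtype):
    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])

print(is_float_dtype(torch.randn(4).dtype))         # True  (float32)
print(is_float_dtype(torch.arange(4).dtype))        # False (int64) -- the case the commit guards against
print(is_float_dtype(torch.randn(4).half().dtype))  # True  (float16)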
@@ -968,6 +971,8 @@ class DiffusionPriorNetwork(nn.Module):
             Rearrange('b (n d) -> b n d', n = num_text_embeds)
         )
 
+        self.continuous_embedded_time = not exists(num_timesteps)
+
         self.to_time_embeds = nn.Sequential(
             nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
             Rearrange('b (n d) -> b n d', n = num_time_embeds)
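Why the new continuous_embedded_time flag matters, as a small sketch (illustrative wiring, not the repo's; the plain `num_timesteps is None` test stands in for `not exists(num_timesteps)`): the discrete branch is an nn.Embedding lookup that expects integer indices, while the continuous SinusoidalPosEmb + MLP branch treats the timestep as a float signal, so the forward pass needs to know which kind of input it is about to feed.

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, num_time_embeds, num_timesteps = 64, 2, 1000

# discrete branch (num_timesteps given): integer timestep indices are exactly what nn.Embedding wants
discrete_time_embeds = nn.Sequential(
    nn.Embedding(num_timesteps, dim * num_time_embeds),
    Rearrange('b (n d) -> b n d', n = num_time_embeds)
)

timesteps = torch.randint(0, num_timesteps, (8,))   # int64 indices
print(discrete_time_embeds(timesteps).shape)        # torch.Size([8, 2, 64])

# the flag recorded in __init__: True only when no fixed number of timesteps is given,
# i.e. when the continuous branch (which needs float input) is used instead
continuous_embedded_time = num_timesteps is None
print(continuous_embedded_time)                     # False here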
@@ -1095,6 +1100,9 @@ class DiffusionPriorNetwork(nn.Module):
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
 
+        if self.continuous_embedded_time:
+            diffusion_timesteps = diffusion_timesteps.type(dtype)
+
         time_embed = self.to_time_embeds(diffusion_timesteps)
 
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
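A hedged sketch of the cast added above. The variable names follow the hunk, but the snippet is simplified: in the network, dtype is taken from an existing float tensor in the forward pass, which this sketch fakes with a random fp16 embedding.

import torch

batch = 8
image_embed = torch.randn(batch, 64, dtype = torch.float16)   # stand-in for an existing float tensor (e.g. under mixed precision)
diffusion_timesteps = torch.randint(0, 1000, (batch,))        # int64, as produced by timestep sampling

continuous_embedded_time = True   # i.e. num_timesteps was not given
dtype = image_embed.dtype

if continuous_embedded_time:
    diffusion_timesteps = diffusion_timesteps.type(dtype)     # int64 -> float16, so the continuous time embedding sees the model dtype

print(diffusion_timesteps.dtype)  # torch.float16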
@@ -1538,6 +1546,8 @@ class SinusoidalPosEmb(nn.Module):
 
     def forward(self, x):
         dtype, device = x.dtype, x.device
+        assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
+
         half_dim = self.dim // 2
         emb = math.log(10000) / (half_dim - 1)
         emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
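To show what the new assert buys, a self-contained sketch of SinusoidalPosEmb: the forward body up to torch.exp mirrors the hunk, while the lines after it are not shown in the diff and are written here only as an assumption to make the module runnable.

import math
import torch
from torch import nn

def is_float_dtype(dtype):
    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        dtype, device = x.dtype, x.device
        assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'

        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
        emb = x[:, None] * emb[None, :]                       # outer product: (batch, half_dim); assumed tail of forward
        return torch.cat((emb.sin(), emb.cos()), dim = -1)

pos_emb = SinusoidalPosEmb(64)
print(pos_emb(torch.rand(4)).shape)                           # torch.Size([4, 64]) -- float input passes the assert

try:
    pos_emb(torch.arange(4))                                  # int64 input
except AssertionError as e:
    print(e)                                                  # 'input to sinusoidal pos emb must be a float type'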