mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
sinusoidal embed time embeddings for diffusion prior as well, for continuous version
This commit is contained in:
@@ -706,7 +706,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user