mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-21 18:44:20 +01:00
offer continuously parameterized time embedding for the diffusion prior network; remove a hyperparameter that may trip people up if not set correctly
@@ -162,7 +162,6 @@ clip = CLIP(
 
 prior_network = DiffusionPriorNetwork(
     dim = 512,
-    num_timesteps = 100,
     depth = 6,
     dim_head = 64,
     heads = 8
@@ -251,7 +250,6 @@ loss.backward()
 
 prior_network = DiffusionPriorNetwork(
     dim = 512,
-    num_timesteps = 100,
     depth = 6,
     dim_head = 64,
     heads = 8
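In the two README hunks above, the only change is dropping num_timesteps = 100 from the example calls. A minimal sketch of the updated usage, assuming the package-level import used elsewhere in the README; leaving num_timesteps unset now selects the continuous time embedding, while passing an integer keeps the old discrete lookup:

from dalle2_pytorch import DiffusionPriorNetwork

# updated README-style construction: no num_timesteps, so the network
# falls back to the continuous (MLP-based) time embedding added further down
prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

# passing num_timesteps explicitly restores the discrete nn.Embedding behaviour
discrete_prior_network = DiffusionPriorNetwork(
    dim = 512,
    num_timesteps = 100,
    depth = 6,
    dim_head = 64,
    heads = 8
)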
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from torch import nn, einsum
 
 from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
 
@@ -124,6 +125,43 @@ class PreNormResidual(nn.Module):
     def forward(self, x, **kwargs):
         return self.fn(self.norm(x), **kwargs) + x
 
+# mlp
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        *,
+        expansion_factor = 2.,
+        depth = 2,
+        norm = False,
+    ):
+        super().__init__()
+        hidden_dim = int(expansion_factor * dim_out)
+        norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
+
+        layers = [nn.Sequential(
+            nn.Linear(dim_in, hidden_dim),
+            nn.SiLU(),
+            norm_fn()
+        )]
+
+        for _ in range(depth - 1):
+            layers.append(nn.Sequential(
+                nn.Linear(hidden_dim, hidden_dim),
+                nn.SiLU(),
+                norm_fn()
+            ))
+
+        layers.append(nn.Linear(hidden_dim, dim_out))
+        self.net = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.net(x.float())
+
+# feedforward
+
 def FeedForward(dim, mult = 4, dropout = 0.):
     inner_dim = int(mult * dim)
     return nn.Sequential(
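The hunk above adds the MLP helper that backs the continuous time embedding. A standalone sketch of its shape behaviour, assuming the class definition from the hunk is in scope (it is not necessarily re-exported at the package level): with the defaults expansion_factor = 2. and depth = 2, MLP(1, 512) is Linear(1, 1024) -> SiLU -> Linear(1024, 1024) -> SiLU -> Linear(1024, 512), with the LayerNorms replaced by Identity since norm = False.

import torch

# hypothetical standalone check; `MLP` is the class from the hunk above,
# assumed to be defined in (or copied into) the current scope
mlp = MLP(1, 512)                # hidden_dim = int(2. * 512) = 1024, depth = 2, no LayerNorm
times = torch.rand(4, 1)         # a batch of 4 scalar timesteps, already shaped (b, 1)
print(mlp(times).shape)          # torch.Size([4, 512])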
@@ -134,6 +172,8 @@ def FeedForward(dim, mult = 4, dropout = 0.):
         nn.Linear(inner_dim, dim, bias = False)
     )
 
+# attention
+
 class Attention(nn.Module):
     def __init__(
         self,
@@ -235,11 +275,11 @@ class DiffusionPriorNetwork(nn.Module):
     def __init__(
         self,
         dim,
-        num_timesteps = 1000,
+        num_timesteps = None,
         **kwargs
     ):
         super().__init__()
-        self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
+        self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
 
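The last hunk is where the commit message lands: num_timesteps now defaults to None, and when it is absent the network builds its time conditioning from Rearrange plus the 2-layer MLP instead of a fixed-size nn.Embedding table. A sketch of that conditional in isolation, assuming the file's usual exists null-check helper and the MLP class from the earlier hunk:

import torch
from torch import nn
from einops.layers.torch import Rearrange

def exists(val):                 # the usual null-check helper assumed by the hunk
    return val is not None

dim = 512
num_timesteps = None             # new default; pass an int to get back the discrete table

# same conditional as in the hunk above; `MLP` is the class added earlier, assumed in scope
time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) \
    else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim))

times = torch.rand(4)                  # continuous timesteps, one scalar per sample
print(time_embeddings(times).shape)    # torch.Size([4, 512]); Rearrange turns (4,) into (4, 1) first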