offer continuously parameterized time embedding for diffusion prior network, remove a hyperparameter that may trip up people, if not set correctly

This commit is contained in:
Phil Wang
2022-04-14 08:28:11 -07:00
parent 7e93b9d3c8
commit 7fb3f695d5
3 changed files with 43 additions and 5 deletions

View File

@@ -162,7 +162,6 @@ clip = CLIP(
prior_network = DiffusionPriorNetwork( prior_network = DiffusionPriorNetwork(
dim = 512, dim = 512,
num_timesteps = 100,
depth = 6, depth = 6,
dim_head = 64, dim_head = 64,
heads = 8 heads = 8
@@ -251,7 +250,6 @@ loss.backward()
prior_network = DiffusionPriorNetwork( prior_network = DiffusionPriorNetwork(
dim = 512, dim = 512,
num_timesteps = 100,
depth = 6, depth = 6,
dim_head = 64, dim_head = 64,
heads = 8 heads = 8

View File

@@ -7,6 +7,7 @@ import torch.nn.functional as F
from torch import nn, einsum from torch import nn, einsum
from einops import rearrange, repeat from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom from einops_exts.torch import EinopsToAndFrom
@@ -124,6 +125,43 @@ class PreNormResidual(nn.Module):
def forward(self, x, **kwargs): def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs) + x return self.fn(self.norm(x), **kwargs) + x
# mlp
class MLP(nn.Module):
def __init__(
self,
dim_in,
dim_out,
*,
expansion_factor = 2.,
depth = 2,
norm = False,
):
super().__init__()
hidden_dim = int(expansion_factor * dim_out)
norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
layers = [nn.Sequential(
nn.Linear(dim_in, hidden_dim),
nn.SiLU(),
norm_fn()
)]
for _ in range(depth - 1):
layers.append(nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
norm_fn()
))
layers.append(nn.Linear(hidden_dim, dim_out))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x.float())
# feedforward
def FeedForward(dim, mult = 4, dropout = 0.): def FeedForward(dim, mult = 4, dropout = 0.):
inner_dim = int(mult * dim) inner_dim = int(mult * dim)
return nn.Sequential( return nn.Sequential(
@@ -134,6 +172,8 @@ def FeedForward(dim, mult = 4, dropout = 0.):
nn.Linear(inner_dim, dim, bias = False) nn.Linear(inner_dim, dim, bias = False)
) )
# attention
class Attention(nn.Module): class Attention(nn.Module):
def __init__( def __init__(
self, self,
@@ -235,11 +275,11 @@ class DiffusionPriorNetwork(nn.Module):
def __init__( def __init__(
self, self,
dim, dim,
num_timesteps = 1000, num_timesteps = None,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.learned_query = nn.Parameter(torch.randn(dim)) self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs) self.causal_transformer = CausalTransformer(dim = dim, **kwargs)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.7', version = '0.0.8',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',