diff --git a/README.md b/README.md
index 58a6a2c..6eafe96 100644
--- a/README.md
+++ b/README.md
@@ -162,7 +162,6 @@ clip = CLIP(
 
 prior_network = DiffusionPriorNetwork(
     dim = 512,
-    num_timesteps = 100,
     depth = 6,
     dim_head = 64,
     heads = 8
@@ -251,7 +250,6 @@ loss.backward()
 
 prior_network = DiffusionPriorNetwork(
     dim = 512,
-    num_timesteps = 100,
     depth = 6,
     dim_head = 64,
     heads = 8
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 921d0f2..41b0db7 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from torch import nn, einsum
 
 from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
 
@@ -124,6 +125,43 @@ class PreNormResidual(nn.Module):
     def forward(self, x, **kwargs):
         return self.fn(self.norm(x), **kwargs) + x
 
+# mlp
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        *,
+        expansion_factor = 2.,
+        depth = 2,
+        norm = False,
+    ):
+        super().__init__()
+        hidden_dim = int(expansion_factor * dim_out)
+        norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
+
+        layers = [nn.Sequential(
+            nn.Linear(dim_in, hidden_dim),
+            nn.SiLU(),
+            norm_fn()
+        )]
+
+        for _ in range(depth - 1):
+            layers.append(nn.Sequential(
+                nn.Linear(hidden_dim, hidden_dim),
+                nn.SiLU(),
+                norm_fn()
+            ))
+
+        layers.append(nn.Linear(hidden_dim, dim_out))
+        self.net = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.net(x.float())
+
+# feedforward
+
 def FeedForward(dim, mult = 4, dropout = 0.):
     inner_dim = int(mult * dim)
     return nn.Sequential(
@@ -134,6 +172,8 @@ def FeedForward(dim, mult = 4, dropout = 0.):
         nn.Linear(inner_dim, dim, bias = False)
     )
 
+# attention
+
 class Attention(nn.Module):
     def __init__(
         self,
@@ -235,11 +275,11 @@ class DiffusionPriorNetwork(nn.Module):
     def __init__(
         self,
         dim,
-        num_timesteps = 1000,
+        num_timesteps = None,
         **kwargs
     ):
         super().__init__()
-        self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
+        self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
 
diff --git a/setup.py b/setup.py
index c514016..9d71c34 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.7',
+  version = '0.0.8',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',