Compare commits


6 Commits
0.0.7 ... 0.0.9

3 changed files with 83 additions and 22 deletions

README.md

@@ -22,19 +22,11 @@ For all of you emailing me (there is a lot), the best way to contribute is throu
 $ pip install dalle2-pytorch
 ```
 
-## CLI Usage (work in progress)
-
-```bash
-$ dream 'sharing a sunset at the summit of mount everest with my dog'
-```
-
-Once built, images will be saved to the same directory the command is invoked
-
-## Training (for deep learning practitioners)
+## Usage
 
 To train DALLE-2 is a 3 step process, with the training of CLIP being the most important
 
-To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already underway.
+To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already <a href="https://github.com/mlfoundations/open_clip">underway</a>.
 
 This repository will demonstrate integration with `x-clip` for starters
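
For reference on that CLIP step, a minimal training sketch against the `x-clip` package could look like the following; the hyperparameters and mock data are illustrative assumptions, not values taken from this diff.

```python
import torch
from x_clip import CLIP

# a small CLIP purely for illustration
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

# mock batch standing in for a real paired image-text dataset
text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

# contrastive loss; repeat over many batches in a real training loop
loss = clip(text, images, return_loss = True)
loss.backward()
```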
@@ -162,7 +154,6 @@ clip = CLIP(
 prior_network = DiffusionPriorNetwork(
     dim = 512,
-    num_timesteps = 100,
     depth = 6,
     dim_head = 64,
     heads = 8
@@ -251,7 +242,6 @@ loss.backward()
 prior_network = DiffusionPriorNetwork(
     dim = 512,
-    num_timesteps = 100,
     depth = 6,
     dim_head = 64,
     heads = 8
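
Both README examples drop `num_timesteps = 100` because `DiffusionPriorNetwork` now defaults it to `None`, falling back to a continuous MLP timestep embedding (see the library change below). A sketch of the two resulting instantiation modes, with illustrative hyperparameters:

```python
from dalle2_pytorch import DiffusionPriorNetwork

# discrete timesteps: a learned nn.Embedding row per timestep
prior_network_discrete = DiffusionPriorNetwork(
    dim = 512,
    num_timesteps = 100,
    depth = 6,
    dim_head = 64,
    heads = 8
)

# continuous timesteps: omitting num_timesteps now routes the raw
# timestep value through a 2 layer MLP instead of an embedding table
prior_network_continuous = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)
```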
@@ -308,6 +298,18 @@ Everything in this readme should run without error
 For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
 
+## CLI Usage (work in progress)
+
+```bash
+$ dream 'sharing a sunset at the summit of mount everest with my dog'
+```
+
+Once built, images will be saved to the same directory the command is invoked
+
+## Training wrapper (wip)
+
+Offer training wrappers
+
 ## Training CLI (wip)
 
 <a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>

dalle2_pytorch/dalle2_pytorch.py

@@ -7,9 +7,12 @@ import torch.nn.functional as F
 from torch import nn, einsum
 
 from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
 
+from kornia.filters import filter2d
+
 from dalle2_pytorch.tokenizer import tokenizer
 
 # use x-clip
@@ -124,6 +127,43 @@ class PreNormResidual(nn.Module):
     def forward(self, x, **kwargs):
         return self.fn(self.norm(x), **kwargs) + x
 
+# mlp
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        *,
+        expansion_factor = 2.,
+        depth = 2,
+        norm = False,
+    ):
+        super().__init__()
+        hidden_dim = int(expansion_factor * dim_out)
+        norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
+
+        layers = [nn.Sequential(
+            nn.Linear(dim_in, hidden_dim),
+            nn.SiLU(),
+            norm_fn()
+        )]
+
+        for _ in range(depth - 1):
+            layers.append(nn.Sequential(
+                nn.Linear(hidden_dim, hidden_dim),
+                nn.SiLU(),
+                norm_fn()
+            ))
+
+        layers.append(nn.Linear(hidden_dim, dim_out))
+        self.net = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.net(x.float())
+
+# feedforward
+
 def FeedForward(dim, mult = 4, dropout = 0.):
     inner_dim = int(mult * dim)
     return nn.Sequential(
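
The new `MLP` exists to back the continuous timestep embedding. A quick sketch of the intended wiring, pairing it with `Rearrange` exactly as the `DiffusionPriorNetwork` change below does (the import path assumes `MLP` stays in this module and is not re-exported):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange
from dalle2_pytorch.dalle2_pytorch import MLP  # the class added above

dim = 512

# reshape a batch of scalar timesteps (b,) -> (b, 1), then project to (b, dim)
to_time_embed = nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim))

timesteps = torch.rand(4)        # 4 continuous timesteps in [0, 1)
emb = to_time_embed(timesteps)   # shape (4, 512)
```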
@@ -134,6 +174,8 @@ def FeedForward(dim, mult = 4, dropout = 0.):
         nn.Linear(inner_dim, dim, bias = False)
     )
 
+# attention
+
 class Attention(nn.Module):
     def __init__(
         self,
@@ -235,27 +277,26 @@ class DiffusionPriorNetwork(nn.Module):
     def __init__(
         self,
         dim,
-        num_timesteps = 1000,
+        num_timesteps = None,
         **kwargs
     ):
         super().__init__()
-        self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
+        self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
 
     def forward_with_cond_scale(
         self,
-        x,
         *args,
         cond_scale = 1.,
         **kwargs
     ):
-        logits = self.forward(x, *args, **kwargs)
+        logits = self.forward(*args, **kwargs)
 
         if cond_scale == 1:
             return logits
 
-        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
+        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale
 
     def forward(
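
The `forward_with_cond_scale` change here (mirrored in `Unet` below) only stops threading an explicit `x` argument through; the classifier-free guidance arithmetic is untouched: one conditioned pass, one pass with conditioning fully dropped, blended as `null + (cond - null) * scale`. A standalone illustration of that blend on mock tensors:

```python
import torch

def guided(cond_logits, null_logits, cond_scale = 1.):
    # cond_scale = 1 recovers the conditioned output exactly;
    # cond_scale > 1 extrapolates along the (cond - null) direction
    return null_logits + (cond_logits - null_logits) * cond_scale

cond_logits = torch.randn(2, 512)  # mock conditioned prediction
null_logits = torch.randn(2, 512)  # mock unconditioned prediction

assert torch.allclose(guided(cond_logits, null_logits, 1.), cond_logits)
amplified = guided(cond_logits, null_logits, 3.)  # stronger guidance
```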
@@ -275,8 +316,15 @@ class DiffusionPriorNetwork(nn.Module):
         text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
 
+        # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
+        # but let's just do it right
+
         if exists(mask):
-            mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
+            all_masked_out = mask.any(dim = -1)
+            mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1)
+
+        if exists(mask):
+            mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
 
         time_embed = self.time_embeddings(diffusion_timesteps)
         time_embed = rearrange(time_embed, 'b d -> b 1 d')
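
Mechanically, the new masking appends one column recording whether the text sequence had any valid token at all, then pads `True` for positions that should always be attended to. A self-contained illustration of the two tensor operations (the token layout is my reading of this diff, not verified against the full file):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

# mock text-token mask for a batch of 2; the second row is fully padded out
mask = torch.tensor([[True, True, False],
                     [False, False, False]])

# one extra column: is any text token valid? guards the pooled text embedding
all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1)

# always-on positions for the remaining appended tokens
mask = F.pad(mask, (0, 2), value = True)

print(mask.shape)  # torch.Size([2, 6])
print(mask[1])     # tensor([False, False, False, False,  True,  True])
```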
@@ -635,17 +683,16 @@ class Unet(nn.Module):
     def forward_with_cond_scale(
         self,
-        x,
         *args,
         cond_scale = 1.,
         **kwargs
     ):
-        logits = self.forward(x, *args, **kwargs)
+        logits = self.forward(*args, **kwargs)
 
         if cond_scale == 1:
             return logits
 
-        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
+        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale
 
     def forward(
@@ -693,6 +740,17 @@ class Unet(nn.Module):
         return self.final_conv(x)
 
+class Blur(nn.Module):
+    def __init__(self):
+        super().__init__()
+        filt = torch.Tensor([1, 2, 1])
+        self.register_buffer('filt', filt)
+
+    def forward(self, x):
+        filt = self.filt
+        filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1')
+        return filter2d(x, filt, normalized = True)
+
 class Decoder(nn.Module):
     def __init__(
         self,
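
The new `Blur` module builds a 3x3 binomial kernel as the outer product of `[1, 2, 1]` with itself and hands it to kornia's `filter2d`, the same anti-aliasing trick as in the linked stylegan2-pytorch template. A standalone sketch of the kernel construction; note that `filter2d` expects a batched kernel, so this sketch shapes it explicitly as `(1, 3, 3)`:

```python
import torch
from einops import rearrange
from kornia.filters import filter2d

filt = torch.tensor([1., 2., 1.])

# outer product -> [[1, 2, 1], [2, 4, 2], [1, 2, 1]], batched as (1, 3, 3)
kernel = rearrange(filt, 'j -> 1 1 j') * rearrange(filt, 'i -> 1 i 1')

x = torch.randn(2, 3, 64, 64)
blurred = filter2d(x, kernel, normalized = True)  # kernel rescaled to sum to 1
print(blurred.shape)  # torch.Size([2, 3, 64, 64])
```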

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.7',
+  version = '0.0.9',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -25,6 +25,7 @@ setup(
     'click',
     'einops>=0.4',
     'einops-exts>=0.0.3',
+    'kornia>=0.5.4',
     'pillow',
     'torch>=1.10',
     'torchvision',