Compare commits


12 Commits

3 changed files with 255 additions and 73 deletions

View File

@@ -22,19 +22,11 @@ For all of you emailing me (there is a lot), the best way to contribute is throu
$ pip install dalle2-pytorch $ pip install dalle2-pytorch
``` ```
## CLI Usage (work in progress) ## Usage
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory from which the command is invoked
## Training (for deep learning practitioners)
Training DALLE-2 is a 3-step process, with the training of CLIP being the most important
To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already underway. To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already <a href="https://github.com/mlfoundations/open_clip">underway</a>.
This repository will demonstrate integration with `x-clip` for starters This repository will demonstrate integration with `x-clip` for starters
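As a rough sketch of that first step, training CLIP through this repository could look like the following (the constructor arguments and sizes here are illustrative assumptions in the style of `x-clip`, and may not match this exact revision):

```python
import torch
from dalle2_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# one contrastive training step
loss = clip(text, images, return_loss = True)
loss.backward()
```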
@@ -109,7 +101,7 @@ clip = CLIP(
unet = Unet( unet = Unet(
dim = 128, dim = 128,
image_embed_dim = 512, image_embed_dim = 512,
time_dim = 128, cond_dim = 128,
channels = 3, channels = 3,
dim_mults=(1, 2, 4, 8) dim_mults=(1, 2, 4, 8)
).cuda() ).cuda()
@@ -162,7 +154,6 @@ clip = CLIP(
prior_network = DiffusionPriorNetwork( prior_network = DiffusionPriorNetwork(
dim = 512, dim = 512,
num_timesteps = 100,
depth = 6, depth = 6,
dim_head = 64, dim_head = 64,
heads = 8 heads = 8
@@ -251,7 +242,6 @@ loss.backward()
prior_network = DiffusionPriorNetwork( prior_network = DiffusionPriorNetwork(
dim = 512, dim = 512,
num_timesteps = 100,
depth = 6, depth = 6,
dim_head = 64, dim_head = 64,
heads = 8 heads = 8
@@ -274,7 +264,7 @@ loss.backward()
unet = Unet( unet = Unet(
dim = 128, dim = 128,
image_embed_dim = 512, image_embed_dim = 512,
time_dim = 128, cond_dim = 128,
channels = 3, channels = 3,
dim_mults=(1, 2, 4, 8) dim_mults=(1, 2, 4, 8)
).cuda() ).cuda()
@@ -286,7 +276,7 @@ decoder = Decoder(
cond_drop_prob = 0.2 cond_drop_prob = 0.2
).cuda() ).cuda()
loss = decoder(images) loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward() loss.backward()
# do above for many steps # do above for many steps
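# do above for many steps

For the optional text conditioning mentioned in the comment above, a minimal sketch (assuming `text` is a batch of token ids sized for the wrapped CLIP; the vocabulary size and sequence length below are illustrative):

```python
text = torch.randint(0, 49408, (4, 256)).cuda()

loss = decoder(images, text) # condition on the text encodings as well
loss.backward()
```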
@@ -308,6 +298,18 @@ Everything in this readme should run without error
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training. For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## CLI Usage (work in progress)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory from which the command is invoked
## Training wrapper (wip)
Offer training wrappers
## Training CLI (wip) ## Training CLI (wip)
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a> <a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
@@ -317,7 +319,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon - [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work) - [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step - [x] make sure it works end to end to produce an output tensor, taking a single gradient step
- [ ] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference) - [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature - [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper) - [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [ ] train on a toy task, offer in colab - [ ] train on a toy task, offer in colab

View File

@@ -7,9 +7,12 @@ import torch.nn.functional as F
from torch import nn, einsum from torch import nn, einsum
from einops import rearrange, repeat from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom from einops_exts.torch import EinopsToAndFrom
from kornia.filters import filter2d
from dalle2_pytorch.tokenizer import tokenizer from dalle2_pytorch.tokenizer import tokenizer
# use x-clip # use x-clip
@@ -115,25 +118,72 @@ class ChanRMSNorm(RMSNorm):
inv_norm = torch.rsqrt(squared_sum + self.eps) inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
class PreNormResidual(nn.Module): class Residual(nn.Module):
def __init__(self, dim, fn): def __init__(self, fn):
super().__init__() super().__init__()
self.fn = fn self.fn = fn
self.norm = RMSNorm(dim)
def forward(self, x, **kwargs): def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs) + x return self.fn(x, **kwargs) + x
# mlp
class MLP(nn.Module):
def __init__(
self,
dim_in,
dim_out,
*,
expansion_factor = 2.,
depth = 2,
norm = False,
):
super().__init__()
hidden_dim = int(expansion_factor * dim_out)
norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
layers = [nn.Sequential(
nn.Linear(dim_in, hidden_dim),
nn.SiLU(),
norm_fn()
)]
for _ in range(depth - 1):
layers.append(nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
norm_fn()
))
layers.append(nn.Linear(hidden_dim, dim_out))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x.float())
# feedforward
class SwiGLU(nn.Module):
""" used successfully in https://arxiv.org/abs/2204.0231 """
def forward(self, x):
x, gate = x.chunk(2, dim = -1)
return x * F.silu(gate)
def FeedForward(dim, mult = 4, dropout = 0.): def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
""" post-activation norm https://arxiv.org/abs/2110.09456 """
inner_dim = int(mult * dim) inner_dim = int(mult * dim)
return nn.Sequential( return nn.Sequential(
RMSNorm(dim), RMSNorm(dim),
nn.Linear(dim, inner_dim, bias = False), nn.Linear(dim, inner_dim * 2, bias = False),
nn.GELU(), SwiGLU(),
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False) nn.Linear(inner_dim, dim, bias = False)
) )
# attention
class Attention(nn.Module): class Attention(nn.Module):
def __init__( def __init__(
self, self,
@@ -189,6 +239,7 @@ class Attention(nn.Module):
sim = sim - sim.amax(dim = -1, keepdim = True) sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1) attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
out = einsum('b h i j, b j d -> b h i d', attn, v) out = einsum('b h i j, b j d -> b h i d', attn, v)
@@ -235,27 +286,26 @@ class DiffusionPriorNetwork(nn.Module):
def __init__( def __init__(
self, self,
dim, dim,
num_timesteps = 1000, num_timesteps = None,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.learned_query = nn.Parameter(torch.randn(dim)) self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs) self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
def forward_with_cond_scale( def forward_with_cond_scale(
self, self,
x,
*args, *args,
cond_scale = 1., cond_scale = 1.,
**kwargs **kwargs
): ):
logits = self.forward(x, *args, **kwargs) logits = self.forward(*args, **kwargs)
if cond_scale == 1: if cond_scale == 1:
return logits return logits
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs) null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale return null_logits + (logits - null_logits) * cond_scale
def forward( def forward(
@@ -275,8 +325,15 @@ class DiffusionPriorNetwork(nn.Module):
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d') text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if exists(mask): if exists(mask):
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query not_all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps) time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d') time_embed = rearrange(time_embed, 'b d -> b 1 d')
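# usage sketch for the timestep embedding change above - passing num_timesteps keeps the discrete nn.Embedding,
# while omitting it (the new default of None) switches to the continuous 2-layer MLP embedding;
# the argument values are illustrative, matching the readme examples
prior_network = DiffusionPriorNetwork(dim = 512, num_timesteps = 1000, depth = 6, dim_head = 64, heads = 8) # discrete timesteps
prior_network = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8) # continuous timesteps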
@@ -514,6 +571,17 @@ def Upsample(dim):
def Downsample(dim): def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1) return nn.Conv2d(dim, dim, 4, 2, 1)
class Blur(nn.Module):
def __init__(self):
super().__init__()
filt = torch.Tensor([1, 2, 1])
self.register_buffer('filt', filt)
def forward(self, x):
filt = self.filt
filt = rearrange(filt, '... j -> ... 1 1 j') * rearrange(filt, '... i -> ... 1 i 1') # build a (1, 3, 3) kernel, since kornia's filter2d expects a batched 2d kernel
return filter2d(x, filt, normalized = True)
class SinusoidalPosEmb(nn.Module): class SinusoidalPosEmb(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
@@ -541,10 +609,17 @@ class ConvNextBlock(nn.Module):
super().__init__() super().__init__()
need_projection = dim != dim_out need_projection = dim != dim_out
self.mlp = nn.Sequential( self.cross_attn = None
nn.GELU(),
nn.Linear(cond_dim, dim) if exists(cond_dim):
) if exists(cond_dim) else None self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim,
context_dim = cond_dim
)
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim) self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
@@ -561,21 +636,82 @@ class ConvNextBlock(nn.Module):
def forward(self, x, cond = None): def forward(self, x, cond = None):
h = self.ds_conv(x) h = self.ds_conv(x)
if exists(self.mlp): if exists(self.cross_attn):
assert exists(cond) assert exists(cond)
condition = self.mlp(cond) h = self.cross_attn(h, context = cond) + h
h = h + rearrange(condition, 'b c -> b c 1 1')
h = self.net(h) h = self.net(h)
return h + self.res_conv(x) return h + self.res_conv(x)
class CrossAttention(nn.Module):
def __init__(
self,
dim,
*,
context_dim = None,
dim_head = 64,
heads = 8,
dropout = 0.,
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
context_dim = default(context_dim, dim)
self.norm = RMSNorm(dim)
self.norm_context = RMSNorm(context_dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
context = self.norm_context(context)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
# add null key / value for classifier free guidance in prior net
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
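# usage sketch for the CrossAttention module above (shapes are illustrative) - image feature tokens
# attend over text encodings, with the learned null key / value giving fully-dropped conditioning
# something to attend to during classifier free guidance
x = torch.randn(2, 64, 128) # (batch, image feature tokens, dim)
context = torch.randn(2, 256, 512) # (batch, text encoding tokens, context_dim)
cross_attn = CrossAttention(dim = 128, context_dim = 512)
out = cross_attn(x, context) # (2, 64, 128)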
class Unet(nn.Module): class Unet(nn.Module):
def __init__( def __init__(
self, self,
dim, dim,
*, *,
image_embed_dim, image_embed_dim,
time_dim = None, cond_dim = None,
num_image_tokens = 4,
out_dim = None, out_dim = None,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8),
channels = 3, channels = 3,
@@ -586,18 +722,31 @@ class Unet(nn.Module):
dims = [channels, *map(lambda m: dim * m, dim_mults)] dims = [channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:])) in_out = list(zip(dims[:-1], dims[1:]))
time_dim = default(time_dim, dim) # time, image embeddings, and optional text encoding
cond_dim = default(cond_dim, dim)
self.time_mlp = nn.Sequential( self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim), SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4), nn.Linear(dim, dim * 4),
nn.GELU(), nn.GELU(),
nn.Linear(dim * 4, dim) nn.Linear(dim * 4, cond_dim),
Rearrange('b d -> b 1 d')
) )
self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim)) self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
cond_dim = time_dim + image_embed_dim self.text_to_cond = nn.LazyLinear(cond_dim)
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
# layers
self.downs = nn.ModuleList([]) self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([]) self.ups = nn.ModuleList([])
@@ -607,7 +756,7 @@ class Unet(nn.Module):
is_last = ind >= (num_resolutions - 1) is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([ self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0), ConvNextBlock(dim_in, dim_out, norm = ind != 0),
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim), ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
Downsample(dim_out) if not is_last else nn.Identity() Downsample(dim_out) if not is_last else nn.Identity()
])) ]))
@@ -615,7 +764,7 @@ class Unet(nn.Module):
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim) self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim))) self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim) self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -635,17 +784,16 @@ class Unet(nn.Module):
def forward_with_cond_scale( def forward_with_cond_scale(
self, self,
x,
*args, *args,
cond_scale = 1., cond_scale = 1.,
**kwargs **kwargs
): ):
logits = self.forward(x, *args, **kwargs) logits = self.forward(*args, **kwargs)
if cond_scale == 1: if cond_scale == 1:
return logits return logits
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs) null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale return null_logits + (logits - null_logits) * cond_scale
def forward( def forward(
@@ -658,37 +806,59 @@ class Unet(nn.Module):
cond_drop_prob = 0. cond_drop_prob = 0.
): ):
batch_size, device = x.shape[0], x.device batch_size, device = x.shape[0], x.device
t = self.time_mlp(time) time_tokens = self.time_mlp(time)
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device) cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
# mask out image embedding depending on condition dropout # mask out image embedding depending on condition dropout
# for classifier free guidance # for classifier free guidance
image_embed = torch.where( image_tokens = self.image_to_cond(image_embed)
rearrange(cond_prob_mask, 'b -> b 1'),
image_embed, image_tokens = torch.where(
rearrange(self.null_image_embed, 'd -> 1 d') cond_prob_mask,
image_tokens,
self.null_image_embed
) )
t = torch.cat((t, image_embed), dim = -1) # take care of text encodings (optional)
if exists(text_encodings):
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
text_tokens,
self.null_text_embed
)
# main conditioning tokens (c)
c = torch.cat((time_tokens, image_tokens), dim = -2)
# text and image conditioning tokens (mid_c)
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
# go through the layers of the unet, down and up
hiddens = [] hiddens = []
for convnext, convnext2, downsample in self.downs: for convnext, convnext2, downsample in self.downs:
x = convnext(x, t) x = convnext(x, c)
x = convnext2(x, t) x = convnext2(x, c)
hiddens.append(x) hiddens.append(x)
x = downsample(x) x = downsample(x)
x = self.mid_block1(x, t) x = self.mid_block1(x, mid_c)
x = self.mid_attn(x) x = self.mid_attn(x)
x = self.mid_block2(x, t) x = self.mid_block2(x, mid_c)
for convnext, convnext2, upsample in self.ups: for convnext, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1) x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, t) x = convnext(x, c)
x = convnext2(x, t) x = convnext2(x, c)
x = upsample(x) x = upsample(x)
return self.final_conv(x) return self.final_conv(x)
@@ -749,6 +919,10 @@ class Decoder(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]
def get_image_embed(self, image): def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image) image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0] image_cls = image_encoding[:, 0]
@@ -776,8 +950,8 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.): def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale)) x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))
if clip_denoised: if clip_denoised:
x_recon.clamp_(-1., 1.) x_recon.clamp_(-1., 1.)
@@ -786,31 +960,32 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad() @torch.no_grad()
def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False): def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised) model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise) noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad() @torch.no_grad()
def p_sample_loop(self, shape, image_embed, cond_scale = 1): def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
device = self.betas.device device = self.betas.device
b = shape[0] b = shape[0]
img = torch.randn(shape, device=device) img = torch.randn(shape, device=device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale) img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
return img return img
@torch.no_grad() @torch.no_grad()
def sample(self, image_embed, cond_scale = 1.): def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0] batch_size = image_embed.shape[0]
image_size = self.image_size image_size = self.image_size
channels = self.channels channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale) text_encodings = self.get_text_encodings(text) if exists(text) else None
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
def q_sample(self, x_start, t, noise=None): def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
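# sampling sketch tying the new text conditioning path together - image_embed would normally come
# from the diffusion prior, and text (optional) is a batch of token ids for the wrapped CLIP;
# the names and values here are illustrative
images = decoder.sample(image_embed = image_embed, text = text, cond_scale = 2.)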
@@ -820,7 +995,7 @@ class Decoder(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
) )
def p_losses(self, x_start, t, *, image_embed, noise = None): def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise) x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -829,6 +1004,7 @@ class Decoder(nn.Module):
x_noisy, x_noisy,
t, t,
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings,
cond_drop_prob = self.cond_drop_prob cond_drop_prob = self.cond_drop_prob
) )
@@ -841,14 +1017,16 @@ class Decoder(nn.Module):
return loss return loss
def forward(self, image): def forward(self, image, text = None):
b, device, img_size, = image.shape[0], image.device, self.image_size b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels) check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long) times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
loss = self.p_losses(image, times, image_embed = image_embed) image_embed = self.get_image_embed(image)
text_encodings = self.get_text_encodings(text) if exists(text) else None
loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
return loss return loss
# main class # main class
@@ -864,11 +1042,12 @@ class DALLE2(nn.Module):
super().__init__() super().__init__()
assert isinstance(prior, DiffusionPrior) assert isinstance(prior, DiffusionPrior)
assert isinstance(decoder, Decoder) assert isinstance(decoder, Decoder)
self.prior = prior.eval() self.prior = prior
self.decoder = decoder.eval() self.decoder = decoder
self.prior_num_samples = prior_num_samples self.prior_num_samples = prior_num_samples
@torch.no_grad() @torch.no_grad()
@eval_decorator
def forward( def forward(
self, self,
text, text,

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.7', version = '0.0.14',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',
@@ -25,6 +25,7 @@ setup(
'click', 'click',
'einops>=0.4', 'einops>=0.4',
'einops-exts>=0.0.3', 'einops-exts>=0.0.3',
'kornia>=0.5.4',
'pillow', 'pillow',
'torch>=1.10', 'torch>=1.10',
'torchvision', 'torchvision',