Compare commits

...

14 Commits

SHA1        Author     Date                        Message
a95de92964  Phil Wang  2022-04-15 06:40:46 -07:00  fix bug with @jihoonerd
5b4ee09625  Phil Wang  2022-04-14 13:48:01 -07:00  ideation
6e27f617f1  Phil Wang  2022-04-14 12:01:09 -07:00  use t5 relative positional bias in prior network causal transformer, since it makes more sense than rotary embeddings
9f55c24db6  Phil Wang  2022-04-14 11:46:45 -07:00  allow for decoder conditioning with the text encodings from CLIP, if it is passed in. use lazy linear to avoid researchers having to worry about text encoding dimensions, but remove later if it does not work well
69e822b7f8  Phil Wang  2022-04-14 10:20:37 -07:00  "project management"
23c401a5d5  Phil Wang  2022-04-14 10:13:43 -07:00  use the eval decorator
68e9883f59  Phil Wang  2022-04-14 10:10:04 -07:00  use cross attention for conditioning unet based on image embedding tokens (which opens up the door on conditioning on text encodings as well
95b018374a  Phil Wang  2022-04-14 09:34:32 -07:00  start using swish glu everywhere, given success of PaLM
8b5c2385b0  Phil Wang  2022-04-14 09:24:31 -07:00  better naming
f2c52d8239  Phil Wang  2022-04-14 09:21:51 -07:00  fix bug with classifier free guidance for prior network, even though it seems it may not be used
97e951221b  Phil Wang  2022-04-14 09:16:09 -07:00  bring in blur, as it will be used somewhere in the cascading DDPM in the decoder eventually, once i figure it out
e1b0c140f1  Phil Wang  2022-04-14 08:51:22 -07:00  cleanup readme
5989569a44  Phil Wang  2022-04-14 08:31:15 -07:00  link to OpenCLIP effort
82464d7bd3  Phil Wang  2022-04-14 08:30:07 -07:00  per-fect
3 changed files with 273 additions and 74 deletions

View File

@@ -22,19 +22,11 @@ For all of you emailing me (there is a lot), the best way to contribute is throu
$ pip install dalle2-pytorch
```
## CLI Usage (work in progress)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory the command is invoked
## Training (for deep learning practitioners)
## Usage
To train DALLE-2 is a 3 step process, with the training of CLIP being the most important
To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already underway.
To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already <a href="https://github.com/mlfoundations/open_clip">underway</a>.
This repository will demonstrate integration with `x-clip` for starters
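As a rough sketch of that first step (a minimal contrastive training loop; the constructor arguments follow the x-clip style API and are illustrative, not prescriptive):

```python
import torch
from dalle2_pytorch import CLIP

# illustrative hyperparameters, assumed x-clip style constructor
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# one contrastive training step
loss = clip(text, images, return_loss = True)
loss.backward()
```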
@@ -109,7 +101,7 @@ clip = CLIP(
unet = Unet(
dim = 128,
image_embed_dim = 512,
time_dim = 128,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
@@ -272,7 +264,7 @@ loss.backward()
unet = Unet(
dim = 128,
image_embed_dim = 512,
time_dim = 128,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
@@ -284,7 +276,7 @@ decoder = Decoder(
cond_drop_prob = 0.2
).cuda()
loss = decoder(images)
loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
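To make the new optional argument concrete: `text` is a batch of token ids which the decoder encodes through CLIP's text transformer internally (see `get_text_encodings` further down in this diff). A minimal sketch with mock data; vocabulary size and sequence length are assumed to match the CLIP used:

```python
# mock token ids and images for the decoder constructed above
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

loss = decoder(images, text)  # text conditioning is optional; the paper hints it adds little
loss.backward()
```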
@@ -306,6 +298,18 @@ Everything in this readme should run without error
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## CLI Usage (work in progress)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory the command is invoked
## Training wrapper (wip)
Offer training wrappers
## Training CLI (wip)
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
@@ -315,12 +319,13 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
- [ ] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [ ] train on a toy task, offer in colab
- [ ] add attention to unet - apply some personal tricks with efficient attention
- [ ] figure out the big idea behind latent diffusion and what can be ported over
- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007
## Citations
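Once a prior and a decoder are trained, the `DALLE2` wrapper further down in this diff ties them together. A rough sketch of the intended end-to-end usage; `diffusion_prior` and `decoder` stand for previously trained instances, and everything beyond the positional text input is assumed:

```python
from dalle2_pytorch import DALLE2

dalle2 = DALLE2(
    prior = diffusion_prior,  # a trained DiffusionPrior
    decoder = decoder         # a trained Decoder
)

# text in, images out: the prior samples image embeddings from text,
# the decoder then denoises images conditioned on those embeddings
images = dalle2(['a butterfly trying to escape a tornado'])
```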

View File

@@ -11,6 +11,8 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters import filter2d
from dalle2_pytorch.tokenizer import tokenizer
# use x-clip
@@ -116,14 +118,13 @@ class ChanRMSNorm(RMSNorm):
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.norm = RMSNorm(dim)
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs) + x
return self.fn(x, **kwargs) + x
# mlp
@@ -160,14 +161,61 @@ class MLP(nn.Module):
def forward(self, x):
return self.net(x.float())
# relative positional bias for causal transformer
class RelPosBias(nn.Module):
def __init__(
self,
heads = 8,
num_buckets = 32,
max_distance = 128,
):
super().__init__()
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(
relative_position,
num_buckets = 32,
max_distance = 128
):
n = -relative_position
n = torch.max(n, torch.zeros_like(n))
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
return torch.where(is_small, n, val_if_large)
def forward(self, i, j, *, device):
q_pos = torch.arange(i, dtype = torch.long, device = device)
k_pos = torch.arange(j, dtype = torch.long, device = device)
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
values = self.relative_attention_bias(rp_bucket)
return rearrange(values, 'i j h -> h i j')
# feedforward
def FeedForward(dim, mult = 4, dropout = 0.):
class SwiGLU(nn.Module):
""" used successfully in https://arxiv.org/abs/2204.0231 """
def forward(self, x):
x, gate = x.chunk(2, dim = -1)
return x * F.silu(gate)
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
""" post-activation norm https://arxiv.org/abs/2110.09456 """
inner_dim = int(mult * dim)
return nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, inner_dim, bias = False),
nn.GELU(),
nn.Linear(dim, inner_dim * 2, bias = False),
SwiGLU(),
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False)
)
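A quick shape sanity check for the two new pieces above; the import path is assumed, and the dimensions are arbitrary:

```python
import torch
from dalle2_pytorch.dalle2_pytorch import RelPosBias, FeedForward  # assumed import path

# T5-style relative position bias: one (i, j) bias map per head, added to the attention logits
bias = RelPosBias(heads = 8)
attn_bias = bias(16, 16, device = torch.device('cpu'))
print(attn_bias.shape)  # torch.Size([8, 16, 16])

# SwiGLU feedforward: the inner projection is doubled, then gated back down
ff = FeedForward(dim = 512, post_activation_norm = True)
x = torch.randn(2, 16, 512)
print(ff(x).shape)      # torch.Size([2, 16, 512])
```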
@@ -198,7 +246,7 @@ class Attention(nn.Module):
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, mask = None):
def forward(self, x, mask = None, attn_bias = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
@@ -215,6 +263,14 @@ class Attention(nn.Module):
q = q * self.scale
sim = einsum('b h i d, b j d -> b h i j', q, k)
# relative positional encoding (T5 style)
if exists(attn_bias):
sim = sim + attn_bias
# masking
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
@@ -227,8 +283,13 @@ class Attention(nn.Module):
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, max_neg_value)
# attention
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate values
out = einsum('b h i j, b j d -> b h i d', attn, v)
@@ -249,7 +310,7 @@ class CausalTransformer(nn.Module):
ff_dropout = 0.
):
super().__init__()
# todo - bring in rotary embeddings or alibi
self.rel_pos_bias = RelPosBias(heads = heads)
self.layers = nn.ModuleList([])
for _ in range(depth):
@@ -265,8 +326,12 @@ class CausalTransformer(nn.Module):
x,
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
):
n, device = x.shape[1], x.device
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
for attn, ff in self.layers:
x = attn(x, mask = mask) + x
x = attn(x, mask = mask, attn_bias = attn_bias) + x
x = ff(x) + x
return self.norm(x)
@@ -285,17 +350,16 @@ class DiffusionPriorNetwork(nn.Module):
def forward_with_cond_scale(
self,
x,
*args,
cond_scale = 1.,
**kwargs
):
logits = self.forward(x, *args, **kwargs)
logits = self.forward(*args, **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
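For reference, the combination above is standard classifier-free guidance; a toy numeric illustration of the formula, nothing repo-specific:

```python
import torch

null_logits = torch.tensor([0.1])  # prediction with conditioning dropped (cond_drop_prob = 1.)
logits = torch.tensor([0.5])       # fully conditioned prediction
cond_scale = 2.

guided = null_logits + (logits - null_logits) * cond_scale
print(guided)  # tensor([0.9000]); cond_scale = 1. recovers the conditioned output, > 1. pushes further along the conditioning direction
```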
def forward(
@@ -315,8 +379,15 @@ class DiffusionPriorNetwork(nn.Module):
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if exists(mask):
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
not_all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d')
@@ -379,7 +450,7 @@ class DiffusionPrior(nn.Module):
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
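The `(0, 1)` to `(1, 0)` padding change above corrects the alignment of the previous-step cumulative product: `alphas_cumprod_prev[t]` should equal `alphas_cumprod[t - 1]`, with the value at t = 0 defined as 1, so the 1 must be prepended rather than appended. A small check:

```python
import torch
import torch.nn.functional as F

alphas_cumprod = torch.tensor([0.99, 0.97, 0.94])

# corrected: prepend 1., so index t lines up with the cumulative product at t - 1
print(F.pad(alphas_cumprod[:-1], (1, 0), value = 1.))  # tensor([1.0000, 0.9900, 0.9700])

# previous behavior: appended the 1., shifting every posterior coefficient off by one step
print(F.pad(alphas_cumprod[:-1], (0, 1), value = 1.))  # tensor([0.9900, 0.9700, 1.0000])
```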
@@ -554,6 +625,17 @@ def Upsample(dim):
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
class Blur(nn.Module):
def __init__(self):
super().__init__()
filt = torch.Tensor([1, 2, 1])
self.register_buffer('filt', filt)
def forward(self, x):
filt = self.filt
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1')
return filter2d(x, filt, normalized = True)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -581,10 +663,17 @@ class ConvNextBlock(nn.Module):
super().__init__()
need_projection = dim != dim_out
self.mlp = nn.Sequential(
nn.GELU(),
nn.Linear(cond_dim, dim)
) if exists(cond_dim) else None
self.cross_attn = None
if exists(cond_dim):
self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim,
context_dim = cond_dim
)
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
@@ -601,21 +690,82 @@ class ConvNextBlock(nn.Module):
def forward(self, x, cond = None):
h = self.ds_conv(x)
if exists(self.mlp):
if exists(self.cross_attn):
assert exists(cond)
condition = self.mlp(cond)
h = h + rearrange(condition, 'b c -> b c 1 1')
h = self.cross_attn(h, context = cond) + h
h = self.net(h)
return h + self.res_conv(x)
class CrossAttention(nn.Module):
def __init__(
self,
dim,
*,
context_dim = None,
dim_head = 64,
heads = 8,
dropout = 0.,
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
context_dim = default(context_dim, dim)
self.norm = RMSNorm(dim)
self.norm_context = RMSNorm(context_dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
context = self.norm_context(context)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
# add null key / value for classifier free guidance in prior net
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
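A shape-level sketch of the new `CrossAttention`: unet feature-map positions act as queries over a small set of conditioning tokens, with a learned null key / value prepended so that fully dropped-out (classifier-free guidance) samples still have something to attend to. Import path and dimensions are assumed:

```python
import torch
from dalle2_pytorch.dalle2_pytorch import CrossAttention  # assumed import path

cross = CrossAttention(dim = 128, context_dim = 512)

feats = torch.randn(2, 64, 128)  # flattened 8x8 feature map from the unet
cond = torch.randn(2, 4, 512)    # e.g. image embedding projected to a few conditioning tokens

out = cross(feats, cond)
print(out.shape)  # torch.Size([2, 64, 128]) - same shape as the queries, ready for the residual add
```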
class Unet(nn.Module):
def __init__(
self,
dim,
*,
image_embed_dim,
time_dim = None,
cond_dim = None,
num_image_tokens = 4,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
@@ -626,18 +776,31 @@ class Unet(nn.Module):
dims = [channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
time_dim = default(time_dim, dim)
# time, image embeddings, and optional text encoding
cond_dim = default(cond_dim, dim)
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
nn.Linear(dim * 4, cond_dim),
Rearrange('b d -> b 1 d')
)
self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
cond_dim = time_dim + image_embed_dim
self.text_to_cond = nn.LazyLinear(cond_dim)
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
@@ -647,7 +810,7 @@ class Unet(nn.Module):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
@@ -655,7 +818,7 @@ class Unet(nn.Module):
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -675,17 +838,16 @@ class Unet(nn.Module):
def forward_with_cond_scale(
self,
x,
*args,
cond_scale = 1.,
**kwargs
):
logits = self.forward(x, *args, **kwargs)
logits = self.forward(*args, **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
@@ -698,37 +860,59 @@ class Unet(nn.Module):
cond_drop_prob = 0.
):
batch_size, device = x.shape[0], x.device
t = self.time_mlp(time)
time_tokens = self.time_mlp(time)
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_embed = torch.where(
rearrange(cond_prob_mask, 'b -> b 1'),
image_embed,
rearrange(self.null_image_embed, 'd -> 1 d')
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
t = torch.cat((t, image_embed), dim = -1)
# take care of text encodings (optional)
if exists(text_encodings):
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
text_tokens,
self.null_text_embed
)
# main conditioning tokens (c)
c = torch.cat((time_tokens, image_tokens), dim = -2)
# text and image conditioning tokens (mid_c)
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
# go through the layers of the unet, down and up
hiddens = []
for convnext, convnext2, downsample in self.downs:
x = convnext(x, t)
x = convnext2(x, t)
x = convnext(x, c)
x = convnext2(x, c)
hiddens.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_block1(x, mid_c)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
x = self.mid_block2(x, mid_c)
for convnext, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, t)
x = convnext2(x, t)
x = convnext(x, c)
x = convnext2(x, c)
x = upsample(x)
return self.final_conv(x)
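Pulling the reworked conditioning together: the unet's ConvNext blocks now cross attend to time tokens plus image-embedding tokens, and additionally to text-encoding tokens only in the innermost blocks to save compute. A rough sketch of one forward pass; keyword names are taken from the calls in this diff, all dimensions are illustrative:

```python
import torch
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
).cuda()

x = torch.randn(4, 3, 256, 256).cuda()       # noised images
times = torch.randint(0, 1000, (4,)).cuda()  # diffusion timesteps
image_embed = torch.randn(4, 512).cuda()     # CLIP image embeddings

pred = unet(x, times, image_embed = image_embed)  # predicted noise, same shape as x
```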
@@ -757,7 +941,7 @@ class Decoder(nn.Module):
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
@@ -789,6 +973,10 @@ class Decoder(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]
def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
@@ -816,8 +1004,8 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))
if clip_denoised:
x_recon.clamp_(-1., 1.)
@@ -826,31 +1014,32 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, image_embed, cond_scale = 1):
def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
return img
@torch.no_grad()
def sample(self, image_embed, cond_scale = 1.):
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
text_encodings = self.get_text_encodings(text) if exists(text) else None
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
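Sampling picks up the same optional text path: `sample` expects token ids, encodes them through CLIP via `get_text_encodings`, and threads the encodings through `p_sample_loop` and `p_sample` into the guided unet call. A rough sketch, assuming a constructed `decoder` as above and illustrative dimensions:

```python
import torch

image_embed = torch.randn(4, 512).cuda()         # e.g. sampled from the diffusion prior
text = torch.randint(0, 49408, (4, 256)).cuda()  # optional token ids for text conditioning

images = decoder.sample(image_embed = image_embed, text = text, cond_scale = 2.)
print(images.shape)  # (batch, channels, image_size, image_size)
```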
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -860,7 +1049,7 @@ class Decoder(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, x_start, t, *, image_embed, noise = None):
def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -869,6 +1058,7 @@ class Decoder(nn.Module):
x_noisy,
t,
image_embed = image_embed,
text_encodings = text_encodings,
cond_drop_prob = self.cond_drop_prob
)
@@ -881,14 +1071,16 @@ class Decoder(nn.Module):
return loss
def forward(self, image):
def forward(self, image, text = None):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
loss = self.p_losses(image, times, image_embed = image_embed)
image_embed = self.get_image_embed(image)
text_encodings = self.get_text_encodings(text) if exists(text) else None
loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
return loss
# main class
@@ -904,11 +1096,12 @@ class DALLE2(nn.Module):
super().__init__()
assert isinstance(prior, DiffusionPrior)
assert isinstance(decoder, Decoder)
self.prior = prior.eval()
self.decoder = decoder.eval()
self.prior = prior
self.decoder = decoder
self.prior_num_samples = prior_num_samples
@torch.no_grad()
@eval_decorator
def forward(
self,
text,

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.8',
version = '0.0.16',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -25,6 +25,7 @@ setup(
'click',
'einops>=0.4',
'einops-exts>=0.0.3',
'kornia>=0.5.4',
'pillow',
'torch>=1.10',
'torchvision',