mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c400d8758c | ||
|
|
bece206699 | ||
|
|
5b4ee09625 | ||
|
|
6e27f617f1 | ||
|
|
9f55c24db6 | ||
|
|
69e822b7f8 | ||
|
|
23c401a5d5 | ||
|
|
68e9883f59 | ||
|
|
95b018374a | ||
|
|
8b5c2385b0 |
@@ -101,7 +101,7 @@ clip = CLIP(
|
||||
unet = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
time_dim = 128,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
@@ -264,7 +264,7 @@ loss.backward()
|
||||
unet = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
time_dim = 128,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
@@ -276,7 +276,7 @@ decoder = Decoder(
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
loss = decoder(images)
|
||||
loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps
|
||||
@@ -319,12 +319,13 @@ Offer training wrappers
|
||||
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
|
||||
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
|
||||
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
|
||||
- [ ] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
|
||||
- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
|
||||
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
|
||||
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] add attention to unet - apply some personal tricks with efficient attention
|
||||
- [ ] figure out the big idea behind latent diffusion and what can be ported over
|
||||
- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from einops.layers.torch import Rearrange
|
||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||
from einops_exts.torch import EinopsToAndFrom
|
||||
|
||||
from kornia.filters import filter2d
|
||||
from kornia.filters.gaussian import GaussianBlur2d
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
|
||||
@@ -118,14 +118,13 @@ class ChanRMSNorm(RMSNorm):
|
||||
inv_norm = torch.rsqrt(squared_sum + self.eps)
|
||||
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
|
||||
|
||||
class PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.norm = RMSNorm(dim)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs) + x
|
||||
return self.fn(x, **kwargs) + x
|
||||
|
||||
# mlp
|
||||
|
||||
@@ -162,14 +161,61 @@ class MLP(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x.float())
|
||||
|
||||
# relative positional bias for causal transformer
|
||||
|
||||
class RelPosBias(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
heads = 8,
|
||||
num_buckets = 32,
|
||||
max_distance = 128,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_buckets = num_buckets
|
||||
self.max_distance = max_distance
|
||||
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(
|
||||
relative_position,
|
||||
num_buckets = 32,
|
||||
max_distance = 128
|
||||
):
|
||||
n = -relative_position
|
||||
n = torch.max(n, torch.zeros_like(n))
|
||||
|
||||
max_exact = num_buckets // 2
|
||||
is_small = n < max_exact
|
||||
|
||||
val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
|
||||
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
||||
return torch.where(is_small, n, val_if_large)
|
||||
|
||||
def forward(self, i, j, *, device):
|
||||
q_pos = torch.arange(i, dtype = torch.long, device = device)
|
||||
k_pos = torch.arange(j, dtype = torch.long, device = device)
|
||||
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
|
||||
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
|
||||
values = self.relative_attention_bias(rp_bucket)
|
||||
return rearrange(values, 'i j h -> h i j')
|
||||
|
||||
# feedforward
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0.):
|
||||
class SwiGLU(nn.Module):
|
||||
""" used successfully in https://arxiv.org/abs/2204.0231 """
|
||||
def forward(self, x):
|
||||
x, gate = x.chunk(2, dim = -1)
|
||||
return x * F.silu(gate)
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
|
||||
""" post-activation norm https://arxiv.org/abs/2110.09456 """
|
||||
|
||||
inner_dim = int(mult * dim)
|
||||
return nn.Sequential(
|
||||
RMSNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias = False),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim, inner_dim * 2, bias = False),
|
||||
SwiGLU(),
|
||||
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim, bias = False)
|
||||
)
|
||||
@@ -200,7 +246,7 @@ class Attention(nn.Module):
|
||||
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x, mask = None):
|
||||
def forward(self, x, mask = None, attn_bias = None):
|
||||
b, n, device = *x.shape[:2], x.device
|
||||
|
||||
x = self.norm(x)
|
||||
@@ -217,6 +263,14 @@ class Attention(nn.Module):
|
||||
q = q * self.scale
|
||||
|
||||
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
||||
|
||||
# relative positional encoding (T5 style)
|
||||
|
||||
if exists(attn_bias):
|
||||
sim = sim + attn_bias
|
||||
|
||||
# masking
|
||||
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
|
||||
if exists(mask):
|
||||
@@ -229,8 +283,13 @@ class Attention(nn.Module):
|
||||
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
|
||||
sim = sim.masked_fill(causal_mask, max_neg_value)
|
||||
|
||||
# attention
|
||||
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True)
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||
|
||||
@@ -251,7 +310,7 @@ class CausalTransformer(nn.Module):
|
||||
ff_dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
# todo - bring in rotary embeddings or alibi
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
@@ -267,8 +326,12 @@ class CausalTransformer(nn.Module):
|
||||
x,
|
||||
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
|
||||
):
|
||||
n, device = x.shape[1], x.device
|
||||
|
||||
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, mask = mask) + x
|
||||
x = attn(x, mask = mask, attn_bias = attn_bias) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
@@ -320,8 +383,8 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
# but let's just do it right
|
||||
|
||||
if exists(mask):
|
||||
all_masked_out = mask.any(dim = -1)
|
||||
mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1)
|
||||
not_all_masked_out = mask.any(dim = -1)
|
||||
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
@@ -387,7 +450,7 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
@@ -589,10 +652,17 @@ class ConvNextBlock(nn.Module):
|
||||
super().__init__()
|
||||
need_projection = dim != dim_out
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.GELU(),
|
||||
nn.Linear(cond_dim, dim)
|
||||
) if exists(cond_dim) else None
|
||||
self.cross_attn = None
|
||||
|
||||
if exists(cond_dim):
|
||||
self.cross_attn = EinopsToAndFrom(
|
||||
'b c h w',
|
||||
'b (h w) c',
|
||||
CrossAttention(
|
||||
dim = dim,
|
||||
context_dim = cond_dim
|
||||
)
|
||||
)
|
||||
|
||||
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
|
||||
|
||||
@@ -609,43 +679,131 @@ class ConvNextBlock(nn.Module):
|
||||
def forward(self, x, cond = None):
|
||||
h = self.ds_conv(x)
|
||||
|
||||
if exists(self.mlp):
|
||||
if exists(self.cross_attn):
|
||||
assert exists(cond)
|
||||
condition = self.mlp(cond)
|
||||
h = h + rearrange(condition, 'b c -> b c 1 1')
|
||||
h = self.cross_attn(h, context = cond) + h
|
||||
|
||||
h = self.net(h)
|
||||
|
||||
return h + self.res_conv(x)
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
context_dim = None,
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
context_dim = default(context_dim, dim)
|
||||
|
||||
self.norm = RMSNorm(dim)
|
||||
self.norm_context = RMSNorm(context_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x, context, mask = None):
|
||||
b, n, device = *x.shape[:2], x.device
|
||||
|
||||
x = self.norm(x)
|
||||
context = self.norm_context(context)
|
||||
|
||||
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
|
||||
|
||||
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
|
||||
|
||||
# add null key / value for classifier free guidance in prior net
|
||||
|
||||
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
|
||||
|
||||
k = torch.cat((nk, k), dim = -2)
|
||||
v = torch.cat((nv, v), dim = -2)
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (1, 0), value = True)
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
sim = sim.masked_fill(~mask, max_neg_value)
|
||||
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True)
|
||||
attn = sim.softmax(dim = -1)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Unet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
image_embed_dim,
|
||||
time_dim = None,
|
||||
cond_dim = None,
|
||||
num_image_tokens = 4,
|
||||
out_dim = None,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
channels = 3,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
lowres_cond_upsample_mode = 'bilinear',
|
||||
blur_sigma = 0.1
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# for eventual cascading diffusion
|
||||
|
||||
self.lowres_cond = lowres_cond
|
||||
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
|
||||
self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
|
||||
|
||||
# determine dimensions
|
||||
|
||||
self.channels = channels
|
||||
|
||||
dims = [channels, *map(lambda m: dim * m, dim_mults)]
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
|
||||
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
|
||||
time_dim = default(time_dim, dim)
|
||||
# time, image embeddings, and optional text encoding
|
||||
|
||||
cond_dim = default(cond_dim, dim)
|
||||
|
||||
self.time_mlp = nn.Sequential(
|
||||
SinusoidalPosEmb(dim),
|
||||
nn.Linear(dim, dim * 4),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim * 4, dim)
|
||||
nn.Linear(dim * 4, cond_dim),
|
||||
Rearrange('b d -> b 1 d')
|
||||
)
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
|
||||
self.image_to_cond = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
cond_dim = time_dim + image_embed_dim
|
||||
self.text_to_cond = nn.LazyLinear(cond_dim)
|
||||
|
||||
# for classifier free guidance
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
self.ups = nn.ModuleList([])
|
||||
@@ -655,7 +813,7 @@ class Unet(nn.Module):
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
|
||||
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
|
||||
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
@@ -663,7 +821,7 @@ class Unet(nn.Module):
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
|
||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
@@ -701,56 +859,85 @@ class Unet(nn.Module):
|
||||
time,
|
||||
*,
|
||||
image_embed,
|
||||
lowres_cond_img = None,
|
||||
text_encodings = None,
|
||||
cond_drop_prob = 0.
|
||||
):
|
||||
batch_size, device = x.shape[0], x.device
|
||||
t = self.time_mlp(time)
|
||||
|
||||
# add low resolution conditioning, if present
|
||||
|
||||
assert not self.lowres_cond and not exists(lowres_cond_img), 'low resolution conditioning image must be present'
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
|
||||
|
||||
lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = self.lowres_cond_upsample_mode)
|
||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||
|
||||
# time conditioning
|
||||
|
||||
time_tokens = self.time_mlp(time)
|
||||
|
||||
# conditional dropout
|
||||
|
||||
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
|
||||
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
|
||||
|
||||
# mask out image embedding depending on condition dropout
|
||||
# for classifier free guidance
|
||||
|
||||
image_embed = torch.where(
|
||||
rearrange(cond_prob_mask, 'b -> b 1'),
|
||||
image_embed,
|
||||
rearrange(self.null_image_embed, 'd -> 1 d')
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
|
||||
image_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
image_tokens,
|
||||
self.null_image_embed
|
||||
)
|
||||
|
||||
t = torch.cat((t, image_embed), dim = -1)
|
||||
# take care of text encodings (optional)
|
||||
|
||||
if exists(text_encodings):
|
||||
text_tokens = self.text_to_cond(text_encodings)
|
||||
text_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
text_tokens,
|
||||
self.null_text_embed
|
||||
)
|
||||
|
||||
# main conditioning tokens (c)
|
||||
|
||||
c = torch.cat((time_tokens, image_tokens), dim = -2)
|
||||
|
||||
# text and image conditioning tokens (mid_c)
|
||||
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
|
||||
|
||||
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
|
||||
|
||||
# go through the layers of the unet, down and up
|
||||
|
||||
hiddens = []
|
||||
|
||||
for convnext, convnext2, downsample in self.downs:
|
||||
x = convnext(x, t)
|
||||
x = convnext2(x, t)
|
||||
x = convnext(x, c)
|
||||
x = convnext2(x, c)
|
||||
hiddens.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
x = self.mid_block1(x, t)
|
||||
x = self.mid_block1(x, mid_c)
|
||||
x = self.mid_attn(x)
|
||||
x = self.mid_block2(x, t)
|
||||
x = self.mid_block2(x, mid_c)
|
||||
|
||||
for convnext, convnext2, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||
x = convnext(x, t)
|
||||
x = convnext2(x, t)
|
||||
x = convnext(x, c)
|
||||
x = convnext2(x, c)
|
||||
x = upsample(x)
|
||||
|
||||
return self.final_conv(x)
|
||||
|
||||
class Blur(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
filt = torch.Tensor([1, 2, 1])
|
||||
self.register_buffer('filt', filt)
|
||||
|
||||
def forward(self, x):
|
||||
filt = self.filt
|
||||
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1')
|
||||
return filter2d(x, filt, normalized = True)
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -775,7 +962,7 @@ class Decoder(nn.Module):
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
@@ -807,6 +994,10 @@ class Decoder(nn.Module):
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
def get_text_encodings(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
return text_encodings[:, 1:]
|
||||
|
||||
def get_image_embed(self, image):
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
@@ -834,8 +1025,8 @@ class Decoder(nn.Module):
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
|
||||
def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))
|
||||
|
||||
if clip_denoised:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
@@ -844,31 +1035,32 @@ class Decoder(nn.Module):
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, shape, image_embed, cond_scale = 1):
|
||||
def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device=device)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
|
||||
return img
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, image_embed, cond_scale = 1.):
|
||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||
batch_size = image_embed.shape[0]
|
||||
image_size = self.image_size
|
||||
channels = self.channels
|
||||
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
@@ -878,7 +1070,7 @@ class Decoder(nn.Module):
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
def p_losses(self, x_start, t, *, image_embed, noise = None):
|
||||
def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
|
||||
@@ -887,6 +1079,7 @@ class Decoder(nn.Module):
|
||||
x_noisy,
|
||||
t,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_drop_prob = self.cond_drop_prob
|
||||
)
|
||||
|
||||
@@ -899,14 +1092,16 @@ class Decoder(nn.Module):
|
||||
|
||||
return loss
|
||||
|
||||
def forward(self, image):
|
||||
def forward(self, image, text = None):
|
||||
b, device, img_size, = image.shape[0], image.device, self.image_size
|
||||
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
|
||||
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
image_embed = self.get_image_embed(image)
|
||||
|
||||
loss = self.p_losses(image, times, image_embed = image_embed)
|
||||
image_embed = self.get_image_embed(image)
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
|
||||
loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
|
||||
return loss
|
||||
|
||||
# main class
|
||||
@@ -922,11 +1117,12 @@ class DALLE2(nn.Module):
|
||||
super().__init__()
|
||||
assert isinstance(prior, DiffusionPrior)
|
||||
assert isinstance(decoder, Decoder)
|
||||
self.prior = prior.eval()
|
||||
self.decoder = decoder.eval()
|
||||
self.prior = prior
|
||||
self.decoder = decoder
|
||||
self.prior_num_samples = prior_num_samples
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def forward(
|
||||
self,
|
||||
text,
|
||||
|
||||
Reference in New Issue
Block a user