Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-19 15:34:21 +01:00

Compare commits (6 commits)
- de0296106b
- eafb136214
- bfbcc283a3
- c30544b73a
- bdf5e9c009
- 9878be760b
**README.md** (55 changed lines)
```diff
@@ -348,7 +348,8 @@ decoder = Decoder(
     image_sizes = (128, 256),
     clip = clip,
     timesteps = 100,
-    cond_drop_prob = 0.2
+    cond_drop_prob = 0.2,
+    condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
 ).cuda()
 
 for unet_number in (1, 2):
```
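For context, a minimal sketch of how the new flag might be used when training the decoder on text as well. The mock tensors and the `text = ...` keyword mirror the surrounding README examples and are assumptions, not part of this commit:

```python
import torch

# assumes `decoder` was constructed as in the README, but with
# condition_on_text_encodings = True, so text must also be supplied
mock_images = torch.randn(4, 3, 256, 256).cuda()
mock_text = torch.randint(0, 49408, (4, 256)).cuda()  # token ids, as in the README mock data

for unet_number in (1, 2):
    # text is encoded by the frozen CLIP inside the decoder and used as conditioning
    loss = decoder(mock_images, text = mock_text, unet_number = unet_number)
    loss.backward()
```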
````diff
@@ -445,6 +446,55 @@ loss.backward()
 # now the diffusion prior can generate image embeddings from the text embeddings
 ```
 
+You can also completely go `CLIP`-less, in which case you will need to pass in the `image_embed_dim` into the `DiffusionPrior` on initialization
+
+```python
+import torch
+from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
+
+# setup prior network, which contains an autoregressive transformer
+
+prior_network = DiffusionPriorNetwork(
+    dim = 512,
+    depth = 6,
+    dim_head = 64,
+    heads = 8
+).cuda()
+
+# diffusion prior network, which contains the CLIP and network (with transformer) above
+
+diffusion_prior = DiffusionPrior(
+    net = prior_network,
+    image_embed_dim = 512, # this needs to be set
+    timesteps = 100,
+    cond_drop_prob = 0.2,
+    condition_on_text_encodings = False # this probably should be true, but just to get Laion started
+).cuda()
+
+# mock data
+
+text = torch.randint(0, 49408, (4, 256)).cuda()
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# precompute the text and image embeddings
+# here using the diffusion prior class, but could be done with CLIP alone
+
+clip_image_embeds = torch.randn(4, 512).cuda()
+clip_text_embeds = torch.randn(4, 512).cuda()
+
+# feed text and images into diffusion prior network
+
+loss = diffusion_prior(
+    text_embed = clip_text_embeds,
+    image_embed = clip_image_embeds
+)
+
+loss.backward()
+
+# do the above for many many many steps
+# now the diffusion prior can generate image embeddings from the text embeddings
+```
+
 ## Experimental
 
 ### DALL-E2 with Latent Diffusion
````
```diff
@@ -593,7 +643,8 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
 - [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
 - [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
-- [ ] spend one day cleaning up tech debt in decoder
+- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
+- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
```
**dalle2_pytorch/dalle2_pytorch.py**

```diff
@@ -84,7 +84,7 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
     if orig_image_size == shape:
         return t
 
-    return F.interpolate(t, size = shape, mode = mode)
+    return F.interpolate(t, size = shape, mode = mode, align_corners = False)
 
 # classifier free guidance functions
 
```
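A small aside on the `align_corners = False` change: for bilinear interpolation this matches PyTorch's default behavior, so the output values should be unchanged; passing the flag explicitly just silences the `UserWarning` that `F.interpolate` emits when it is left unspecified. A quick check, as a sketch rather than code from the repo:

```python
import torch
import torch.nn.functional as F

t = torch.randn(1, 3, 64, 64)

out_default  = F.interpolate(t, size = (128, 128), mode = 'bilinear')                         # warns about align_corners
out_explicit = F.interpolate(t, size = (128, 128), mode = 'bilinear', align_corners = False)  # same values, no warning

assert torch.equal(out_default, out_explicit)
```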
```diff
@@ -143,6 +143,92 @@ def sigmoid_beta_schedule(timesteps):
     return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
 
+class BaseGaussianDiffusion(nn.Module):
+    def __init__(self, *, beta_schedule, timesteps, loss_type):
+        super().__init__()
+
+        if beta_schedule == "cosine":
+            betas = cosine_beta_schedule(timesteps)
+        elif beta_schedule == "linear":
+            betas = linear_beta_schedule(timesteps)
+        elif beta_schedule == "quadratic":
+            betas = quadratic_beta_schedule(timesteps)
+        elif beta_schedule == "jsd":
+            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
+        elif beta_schedule == "sigmoid":
+            betas = sigmoid_beta_schedule(timesteps)
+        else:
+            raise NotImplementedError()
+
+        alphas = 1. - betas
+        alphas_cumprod = torch.cumprod(alphas, axis = 0)
+        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.loss_type = loss_type
+
+        self.register_buffer('betas', betas)
+        self.register_buffer('alphas_cumprod', alphas_cumprod)
+        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+
+        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
+        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
+        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+
+        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+
+        self.register_buffer('posterior_variance', posterior_variance)
+
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+
+        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
+        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
+        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
+
+    def q_mean_variance(self, x_start, t):
+        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def q_sample(self, x_start, t, noise = None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def sample(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward(self, *args, **kwargs):
+        raise NotImplementedError
+
 # diffusion prior
 
 class LayerNorm(nn.Module):
```
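For reference, the buffers registered above correspond to standard DDPM algebra (Ho et al. 2020), not anything introduced by this commit: the closed-form forward noising used by `q_sample` and the posterior used by `q_posterior`:

```latex
% forward noising (q_sample): sqrt_alphas_cumprod and sqrt_one_minus_alphas_cumprod
q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\, x_0,\ (1 - \bar\alpha_t)\, I\big)

% posterior (q_posterior): posterior_mean_coef1, posterior_mean_coef2, posterior_variance
q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\ \tilde\mu_t(x_t, x_0),\ \tilde\beta_t I\big)

\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t}\, x_0
                      + \frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}\, x_t
\qquad
\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t
```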
```diff
@@ -481,12 +567,15 @@ class DiffusionPriorNetwork(nn.Module):
 
         return pred_image_embed
 
-class DiffusionPrior(nn.Module):
+class DiffusionPrior(BaseGaussianDiffusion):
     def __init__(
         self,
         net,
         *,
-        clip,
+        clip = None,
+        image_embed_dim = None,
+        image_size = None,
+        image_channels = 3,
         timesteps = 1000,
         cond_drop_prob = 0.2,
         loss_type = "l1",
```
```diff
@@ -494,15 +583,23 @@ class DiffusionPrior(nn.Module):
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
     ):
-        super().__init__()
-        assert isinstance(clip, CLIP)
-        freeze_model_and_make_eval_(clip)
-        self.clip = clip
+        super().__init__(
+            beta_schedule = beta_schedule,
+            timesteps = timesteps,
+            loss_type = loss_type
+        )
+
+        if exists(clip):
+            assert isinstance(clip, CLIP)
+            freeze_model_and_make_eval_(clip)
+            self.clip = clip
+        else:
+            assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
+            self.clip = None
 
         self.net = net
-        self.image_embed_dim = clip.dim_latent
-        self.channels = clip.image_channels
-        self.image_size = clip.image_size
+        self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
+        self.channels = default(image_channels, lambda: clip.image_channels)
 
         self.cond_drop_prob = cond_drop_prob
         self.condition_on_text_encodings = condition_on_text_encodings
```
```diff
@@ -510,55 +607,10 @@ class DiffusionPrior(nn.Module):
         self.predict_x_start = predict_x_start
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
 
-        if beta_schedule == "cosine":
-            betas = cosine_beta_schedule(timesteps)
-        elif beta_schedule == "linear":
-            betas = linear_beta_schedule(timesteps)
-        elif beta_schedule == "quadratic":
-            betas = quadratic_beta_schedule(timesteps)
-        elif beta_schedule == "jsd":
-            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
-        elif beta_schedule == "sigmoid":
-            betas = sigmoid_beta_schedule(timesteps)
-        else:
-            raise NotImplementedError()
-
-        alphas = 1. - betas
-        alphas_cumprod = torch.cumprod(alphas, axis = 0)
-        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
-
-        timesteps, = betas.shape
-        self.num_timesteps = int(timesteps)
-        self.loss_type = loss_type
-
-        self.register_buffer('betas', betas)
-        self.register_buffer('alphas_cumprod', alphas_cumprod)
-        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
-
-        # calculations for diffusion q(x_t | x_{t-1}) and others
-
-        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
-        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
-        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
-
-        # calculations for posterior q(x_{t-1} | x_t, x_0)
-
-        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
-
-        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
-
-        self.register_buffer('posterior_variance', posterior_variance)
-
-        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
-
-        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
-        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
-        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
-
     @torch.no_grad()
     def get_image_embed(self, image):
+        assert exists(self.clip)
+
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
         image_embed = self.clip.to_visual_latent(image_cls)
```
```diff
@@ -566,6 +618,8 @@ class DiffusionPrior(nn.Module):
 
     @torch.no_grad()
     def get_text_cond(self, text):
+        assert exists(self.clip)
+
         text_encodings = self.clip.text_transformer(text)
         text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
```
```diff
@@ -576,27 +630,6 @@ class DiffusionPrior(nn.Module):
 
         return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
 
-    def q_mean_variance(self, x_start, t):
-        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
-        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
-        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
-        return mean, variance, log_variance
-
-    def predict_start_from_noise(self, x_t, t, noise):
-        return (
-            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
-            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
-        )
-
-    def q_posterior(self, x_start, x_t, t):
-        posterior_mean = (
-            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
-            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
-        )
-        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
-        return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
     def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
         pred = self.net(x, t, **text_cond)
```
```diff
@@ -633,14 +666,6 @@ class DiffusionPrior(nn.Module):
             img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
         return img
 
-    def q_sample(self, x_start, t, noise=None):
-        noise = default(noise, lambda: torch.randn_like(x_start))
-
-        return (
-            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
-            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
-        )
-
     def p_losses(self, image_embed, t, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))
```
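To illustrate what the refactor buys, here is a sketch only; the class name, import path, and toy loss below are assumptions, not code from this commit. A new diffusion variant now only needs to supply its network call and loss, reusing the schedule math inherited from `BaseGaussianDiffusion`:

```python
import torch
import torch.nn.functional as F

# assumed import path; BaseGaussianDiffusion lives alongside DiffusionPrior and Decoder
from dalle2_pytorch.dalle2_pytorch import BaseGaussianDiffusion

class ToyEmbedDiffusion(BaseGaussianDiffusion):
    """Hypothetical minimal subclass: denoises a flat embedding with a given network."""
    def __init__(self, net, *, timesteps = 100):
        super().__init__(beta_schedule = 'cosine', timesteps = timesteps, loss_type = 'l2')
        self.net = net

    def forward(self, x_start):
        b = x_start.shape[0]
        t = torch.randint(0, self.num_timesteps, (b,), device = x_start.device, dtype = torch.long)

        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)  # inherited forward noising

        pred_noise = self.net(x_noisy, t)
        return F.mse_loss(pred_noise, noise)
```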
```diff
@@ -881,6 +906,7 @@ class Unet(nn.Module):
         dim,
         *,
         image_embed_dim,
+        text_embed_dim = None,
         cond_dim = None,
         num_image_tokens = 4,
         num_time_tokens = 2,
```
```diff
@@ -894,6 +920,7 @@ class Unet(nn.Module):
         sparse_attn_window = 8, # window size for sparse attention
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
         cond_on_text_encodings = False,
+        max_text_len = 256,
         cond_on_image_embeds = False,
     ):
         super().__init__()
```
```diff
@@ -933,7 +960,7 @@ class Unet(nn.Module):
             Rearrange('b (n d) -> b n d', n = num_image_tokens)
         ) if image_embed_dim != cond_dim else nn.Identity()
 
-        self.text_to_cond = nn.LazyLinear(cond_dim)
+        self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
 
         # finer control over whether to condition on image embeddings and text encodings
         # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
```
```diff
@@ -944,7 +971,7 @@ class Unet(nn.Module):
         # for classifier free guidance
 
         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
-        self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
+        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
 
         # attention related params
 
```
```diff
@@ -1072,7 +1099,7 @@ class Unet(nn.Module):
             text_tokens = torch.where(
                 cond_prob_mask,
                 text_tokens,
-                self.null_text_embed
+                self.null_text_embed[:, :text_tokens.shape[1]]
             )
 
         # main conditioning tokens (c)
```
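A small sketch (shapes are illustrative assumptions) of why `null_text_embed` is now sized to `max_text_len` and sliced: the per-batch text length varies, and the learned null embedding has to broadcast against whatever length arrives when conditioning is dropped for classifier-free guidance:

```python
import torch

b, n, d, max_text_len = 2, 77, 512, 256   # n = text length of this batch, n <= max_text_len

text_tokens = torch.randn(b, n, d)
null_text_embed = torch.randn(1, max_text_len, d)

# drop text conditioning for a random subset of the batch (classifier-free guidance)
keep_mask = (torch.rand(b, 1, 1) > 0.2)

text_tokens = torch.where(keep_mask, text_tokens, null_text_embed[:, :n])
assert text_tokens.shape == (b, n, d)
```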
```diff
@@ -1152,7 +1179,7 @@ class LowresConditioner(nn.Module):
 
         return cond_fmap
 
-class Decoder(nn.Module):
+class Decoder(BaseGaussianDiffusion):
     def __init__(
         self,
         unet,
```
```diff
@@ -1170,14 +1197,22 @@ class Decoder(nn.Module):
         lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
         blur_sigma = 0.1, # cascading ddpm - blur sigma
         blur_kernel_size = 3, # cascading ddpm - blur kernel size
+        condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
     ):
-        super().__init__()
+        super().__init__(
+            beta_schedule = beta_schedule,
+            timesteps = timesteps,
+            loss_type = loss_type
+        )
+
         assert isinstance(clip, CLIP)
         freeze_model_and_make_eval_(clip)
         self.clip = clip
         self.clip_image_size = clip.image_size
         self.channels = clip.image_channels
 
+        self.condition_on_text_encodings = condition_on_text_encodings
+
         # automatically take care of ensuring that first unet is unconditional
         # while the rest of the unets are conditioned on the low resolution image produced by previous unet
 
```
```diff
@@ -1233,55 +1268,6 @@ class Decoder(nn.Module):
 
         self.cond_drop_prob = cond_drop_prob
 
-        # noise schedule
-
-        if beta_schedule == "cosine":
-            betas = cosine_beta_schedule(timesteps)
-        elif beta_schedule == "linear":
-            betas = linear_beta_schedule(timesteps)
-        elif beta_schedule == "quadratic":
-            betas = quadratic_beta_schedule(timesteps)
-        elif beta_schedule == "jsd":
-            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
-        elif beta_schedule == "sigmoid":
-            betas = sigmoid_beta_schedule(timesteps)
-        else:
-            raise NotImplementedError()
-
-        alphas = 1. - betas
-        alphas_cumprod = torch.cumprod(alphas, axis = 0)
-        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
-
-        timesteps, = betas.shape
-        self.num_timesteps = int(timesteps)
-        self.loss_type = loss_type
-
-        self.register_buffer('betas', betas)
-        self.register_buffer('alphas_cumprod', alphas_cumprod)
-        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
-
-        # calculations for diffusion q(x_t | x_{t-1}) and others
-
-        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
-        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
-        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
-
-        # calculations for posterior q(x_{t-1} | x_t, x_0)
-
-        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
-
-        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
-
-        self.register_buffer('posterior_variance', posterior_variance)
-
-        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
-
-        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
-        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
-        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
-
     def get_unet(self, unet_number):
         assert 0 < unet_number <= len(self.unets)
         index = unet_number - 1
```
```diff
@@ -1314,27 +1300,6 @@ class Decoder(nn.Module):
         image_embed = self.clip.to_visual_latent(image_cls)
         return l2norm(image_embed)
 
-    def q_mean_variance(self, x_start, t):
-        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
-        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
-        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
-        return mean, variance, log_variance
-
-    def predict_start_from_noise(self, x_t, t, noise):
-        return (
-            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
-            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
-        )
-
-    def q_posterior(self, x_start, x_t, t):
-        posterior_mean = (
-            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
-            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
-        )
-        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
-        return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
     def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
         pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
 
```
```diff
@@ -1379,14 +1344,6 @@ class Decoder(nn.Module):
 
         return img
 
-    def q_sample(self, x_start, t, noise=None):
-        noise = default(noise, lambda: torch.randn_like(x_start))
-
-        return (
-            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
-            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
-        )
-
     def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
```
```diff
@@ -1421,6 +1378,8 @@ class Decoder(nn.Module):
 
         text_encodings = self.get_text_encodings(text) if exists(text) else None
 
+        assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
+
         img = None
 
         for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
```
```diff
@@ -1481,6 +1440,8 @@ class Decoder(nn.Module):
 
         text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
 
+        assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
+
         lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
         image = resize_image_to(image, target_image_size)
 
```
```diff
@@ -1508,7 +1469,9 @@ class DALLE2(nn.Module):
         assert isinstance(decoder, Decoder)
         self.prior = prior
         self.decoder = decoder
+
         self.prior_num_samples = prior_num_samples
+        self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
 
     @torch.no_grad()
     @eval_decorator
```
```diff
@@ -1525,7 +1488,9 @@ class DALLE2(nn.Module):
         text = tokenizer.tokenize(text).to(device)
 
         image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
-        images = self.decoder.sample(image_embed, cond_scale = cond_scale)
+
+        text_cond = text if self.decoder_need_text_cond else None
+        images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
 
         if one_text:
             return images[0]
```
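Finally, a usage sketch of the end-to-end class after this change, mirroring the README's example (the prompt string, the `cond_scale` value, and the trained `diffusion_prior` / `decoder` instances are illustrative assumptions): when the decoder was built with `condition_on_text_encodings = True`, the tokenized text is now forwarded to `decoder.sample` automatically.

```python
from dalle2_pytorch import DALLE2

# prior and decoder assumed to be trained instances as constructed in the README examples
dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

# text is tokenized internally, fed to the prior, and (if the decoder conditions
# on text encodings) also routed to the decoder during sampling
images = dalle2(
    ['a butterfly trying to escape a tornado'],
    cond_scale = 2. # classifier free guidance strength
)
```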