Compare commits

...

7 Commits

3 changed files with 319 additions and 194 deletions

124
README.md
View File

@@ -348,7 +348,8 @@ decoder = Decoder(
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
cond_drop_prob = 0.2,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
for unet_number in (1, 2):
@@ -376,6 +377,124 @@ You can also train the decoder on images of greater than the size (say 512x512)
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## Training on Preprocessed CLIP Embeddings
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
Working example below
```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
# get trained CLIP from step one
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8,
).cuda()
# setup prior network, which contains an autoregressive transformer
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
# diffusion prior network, which contains the CLIP and network (with transformer) above
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2,
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = diffusion_prior.get_image_embed(images)
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
# feed text and images into diffusion prior network
loss = diffusion_prior(
text_embed = clip_text_embeds,
image_embed = clip_image_embeds
)
loss.backward()
# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings
```
You can also completely go `CLIP`-less, in which case you will need to pass in the `image_embed_dim` into the `DiffusionPrior` on initialization
```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
# setup prior network, which contains an autoregressive transformer
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
# diffusion prior network, which contains the CLIP and network (with transformer) above
diffusion_prior = DiffusionPrior(
net = prior_network,
image_embed_dim = 512, # this needs to be set
timesteps = 100,
cond_drop_prob = 0.2,
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = torch.randn(4, 512).cuda()
clip_text_embeds = torch.randn(4, 512).cuda()
# feed text and images into diffusion prior network
loss = diffusion_prior(
text_embed = clip_text_embeds,
image_embed = clip_image_embeds
)
loss.backward()
# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings
```
## Experimental
### DALL-E2 with Latent Diffusion
@@ -524,7 +643,8 @@ Once built, images will be saved to the same directory the command is invoked
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
- [ ] spend one day cleaning up tech debt in decoder
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs

View File

@@ -84,7 +84,7 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
if orig_image_size == shape:
return t
return F.interpolate(t, size = shape, mode = mode)
return F.interpolate(t, size = shape, mode = mode, align_corners = False)
# classifier free guidance functions
@@ -143,6 +143,92 @@ def sigmoid_beta_schedule(timesteps):
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
class BaseGaussianDiffusion(nn.Module):
def __init__(self, *, beta_schedule, timesteps, loss_type):
super().__init__()
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "quadratic":
betas = quadratic_beta_schedule(timesteps)
elif beta_schedule == "jsd":
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
elif beta_schedule == "sigmoid":
betas = sigmoid_beta_schedule(timesteps)
else:
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.loss_type = loss_type
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def sample(self, *args, **kwargs):
raise NotImplementedError
def forward(self, *args, **kwargs):
raise NotImplementedError
# diffusion prior
class LayerNorm(nn.Module):
@@ -421,25 +507,41 @@ class DiffusionPriorNetwork(nn.Module):
image_embed,
diffusion_timesteps,
*,
text_encodings,
text_embed,
text_encodings = None,
mask = None,
cond_drop_prob = 0.2
):
batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
# in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
# make text encodings optional
# although the paper seems to suggest it is present <--
if not exists(text_encodings):
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
if not exists(mask):
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
# classifier free guidance
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
mask &= cond_prob_mask
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
mask = torch.cat((mask, cond_prob_mask), dim = 1)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if exists(mask):
not_all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
@@ -455,16 +557,6 @@ class DiffusionPriorNetwork(nn.Module):
learned_queries
), dim = -2)
# mask if it doesn't exist
if not exists(mask):
mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)
# classifier free guidance
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
mask &= rearrange(cond_prob_mask, 'b -> b 1')
# attend
tokens = self.causal_transformer(tokens, mask = mask)
@@ -475,81 +567,50 @@ class DiffusionPriorNetwork(nn.Module):
return pred_image_embed
class DiffusionPrior(nn.Module):
class DiffusionPrior(BaseGaussianDiffusion):
def __init__(
self,
net,
*,
clip,
clip = None,
image_embed_dim = None,
image_size = None,
image_channels = 3,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = "l1",
predict_x_start = True,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
super().__init__(
beta_schedule = beta_schedule,
timesteps = timesteps,
loss_type = loss_type
)
if exists(clip):
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
else:
assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
self.clip = None
self.net = net
self.image_embed_dim = clip.dim_latent
self.channels = clip.image_channels
self.image_size = clip.image_size
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
self.channels = default(image_channels, lambda: clip.image_channels)
self.cond_drop_prob = cond_drop_prob
self.condition_on_text_encodings = condition_on_text_encodings
self.predict_x_start = predict_x_start
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "quadratic":
betas = quadratic_beta_schedule(timesteps)
elif beta_schedule == "jsd":
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
elif beta_schedule == "sigmoid":
betas = sigmoid_beta_schedule(timesteps)
else:
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.loss_type = loss_type
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
@torch.no_grad()
def get_image_embed(self, image):
assert exists(self.clip)
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
@@ -557,33 +618,18 @@ class DiffusionPrior(nn.Module):
@torch.no_grad()
def get_text_cond(self, text):
assert exists(self.clip)
text_encodings = self.clip.text_transformer(text)
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
text_embed = l2norm(text_embed)
if not self.condition_on_text_encodings:
return dict(text_embed = text_embed)
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond)
@@ -620,14 +666,6 @@ class DiffusionPrior(nn.Module):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, image_embed, t, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
@@ -679,13 +717,41 @@ class DiffusionPrior(nn.Module):
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, 'b 1 d -> b d')
def forward(self, text, image, *args, **kwargs):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
def forward(
self,
text = None,
image = None,
text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
image_embed = None,
text_encodings = None, # as well as CLIP text encodings
text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
*args,
**kwargs
):
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
text_cond = self.get_text_cond(text)
if exists(image):
image_embed = self.get_image_embed(image)
# calculate text conditionings, based on what is passed in
if exists(text):
text_cond = self.get_text_cond(text)
else:
text_cond = dict(
text_embed = text_embed,
text_encodings = text_encodings,
mask = text_mask
)
# timestep conditioning from ddpm
batch, device = image_embed.shape[0], image_embed.device
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
# calculate forward loss
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
return loss
@@ -853,6 +919,7 @@ class Unet(nn.Module):
sparse_attn_window = 8, # window size for sparse attention
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
max_text_len = 256,
cond_on_image_embeds = False,
):
super().__init__()
@@ -903,7 +970,7 @@ class Unet(nn.Module):
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
# attention related params
@@ -1031,7 +1098,7 @@ class Unet(nn.Module):
text_tokens = torch.where(
cond_prob_mask,
text_tokens,
self.null_text_embed
self.null_text_embed[:, :text_tokens.shape[1]]
)
# main conditioning tokens (c)
@@ -1111,7 +1178,7 @@ class LowresConditioner(nn.Module):
return cond_fmap
class Decoder(nn.Module):
class Decoder(BaseGaussianDiffusion):
def __init__(
self,
unet,
@@ -1129,14 +1196,22 @@ class Decoder(nn.Module):
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
):
super().__init__()
super().__init__(
beta_schedule = beta_schedule,
timesteps = timesteps,
loss_type = loss_type
)
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
self.condition_on_text_encodings = condition_on_text_encodings
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
@@ -1192,55 +1267,6 @@ class Decoder(nn.Module):
self.cond_drop_prob = cond_drop_prob
# noise schedule
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "quadratic":
betas = quadratic_beta_schedule(timesteps)
elif beta_schedule == "jsd":
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
elif beta_schedule == "sigmoid":
betas = sigmoid_beta_schedule(timesteps)
else:
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis = 0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.loss_type = loss_type
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
@@ -1273,27 +1299,6 @@ class Decoder(nn.Module):
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
@@ -1338,14 +1343,6 @@ class Decoder(nn.Module):
return img
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -1380,6 +1377,8 @@ class Decoder(nn.Module):
text_encodings = self.get_text_encodings(text) if exists(text) else None
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
img = None
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
@@ -1440,6 +1439,8 @@ class Decoder(nn.Module):
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)
@@ -1467,7 +1468,9 @@ class DALLE2(nn.Module):
assert isinstance(decoder, Decoder)
self.prior = prior
self.decoder = decoder
self.prior_num_samples = prior_num_samples
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
@torch.no_grad()
@eval_decorator
@@ -1484,7 +1487,9 @@ class DALLE2(nn.Module):
text = tokenizer.tokenize(text).to(device)
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
if one_text:
return images[0]

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.46',
version = '0.0.51',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',