mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6fee4fce6e | ||
|
|
a54e309269 | ||
|
|
c6bfd7fdc8 | ||
|
|
960a79857b | ||
|
|
7214df472d | ||
|
|
00ae50999b | ||
|
|
6cddefad26 | ||
|
|
0332eaa6ff |
109
README.md
109
README.md
@@ -109,7 +109,7 @@ unet = Unet(
|
||||
# decoder, which contains the unet and clip
|
||||
|
||||
decoder = Decoder(
|
||||
net = unet,
|
||||
unet = unet,
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
@@ -182,7 +182,82 @@ loss.backward()
|
||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||
```
|
||||
|
||||
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which both contains `CLIP`, a unet, and a causal transformer)
|
||||
In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.
|
||||
|
||||
This can easily be used within this framework as so
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder, CLIP
|
||||
|
||||
# trained clip from step 1
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 6,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 6,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
).cuda()
|
||||
|
||||
# 2 unets for the decoder (a la cascading DDPM)
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
# decoder, which contains the unet(s) and clip
|
||||
|
||||
decoder = Decoder(
|
||||
clip = clip,
|
||||
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
|
||||
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
|
||||
images = torch.randn(4, 3, 512, 512).cuda()
|
||||
|
||||
# feed images into decoder, specifying which unet you want to train
|
||||
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
|
||||
|
||||
loss = decoder(images, unet_number = 1)
|
||||
loss.backward()
|
||||
|
||||
loss = decoder(images, unet_number = 2)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
|
||||
# chaining the unets from lowest resolution to highest resolution (thus cascading)
|
||||
|
||||
mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
|
||||
```
|
||||
|
||||
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
|
||||
|
||||
```python
|
||||
from dalle2_pytorch import DALLE2
|
||||
@@ -261,7 +336,7 @@ loss.backward()
|
||||
|
||||
# decoder (with unet)
|
||||
|
||||
unet = Unet(
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
@@ -269,15 +344,25 @@ unet = Unet(
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 16,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
decoder = Decoder(
|
||||
net = unet,
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss.backward()
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps
|
||||
|
||||
@@ -291,11 +376,13 @@ images = dalle2(
|
||||
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
|
||||
)
|
||||
|
||||
# save your image
|
||||
# save your image (in this example, of size 256x256)
|
||||
```
|
||||
|
||||
Everything in this readme should run without error
|
||||
|
||||
You can also train the decoder on images of greater than the size (say 512x512) at which CLIP was trained (256x256). The images will be resized to CLIP image resolution for the image embeddings
|
||||
|
||||
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
||||
|
||||
## CLI Usage (work in progress)
|
||||
@@ -321,11 +408,11 @@ Offer training wrappers
|
||||
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
|
||||
- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
|
||||
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
|
||||
- [ ] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
|
||||
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
|
||||
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
|
||||
- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007 (also in separate file as experimental) build out https://github.com/lucidrains/x-unet
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, add efficient attention (conditional on resolution), port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] train on a toy task, offer in colab
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from inspect import isfunction
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -11,7 +12,7 @@ from einops.layers.torch import Rearrange
|
||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||
from einops_exts.torch import EinopsToAndFrom
|
||||
|
||||
from kornia.filters.gaussian import GaussianBlur2d
|
||||
from kornia.filters import gaussian_blur2d
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
|
||||
@@ -29,6 +30,9 @@ def default(val, d):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
def eval_decorator(fn):
|
||||
def inner(model, *args, **kwargs):
|
||||
was_training = model.training
|
||||
@@ -64,6 +68,15 @@ def freeze_model_and_make_eval_(model):
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1)
|
||||
|
||||
def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://github.com/assafshocher/ResizeRight
|
||||
shape = cast_tuple(image_size, 2)
|
||||
orig_image_size = t.shape[-2:]
|
||||
|
||||
if orig_image_size == shape:
|
||||
return t
|
||||
|
||||
return F.interpolate(t, size = shape, mode = mode)
|
||||
|
||||
# classifier free guidance functions
|
||||
|
||||
def prob_mask_like(shape, prob, device):
|
||||
@@ -585,31 +598,6 @@ class DiffusionPrior(nn.Module):
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
|
||||
return img
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, text, num_samples_per_batch = 2):
|
||||
# in the paper, what they did was
|
||||
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
|
||||
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
|
||||
|
||||
batch_size = text.shape[0]
|
||||
image_embed_dim = self.image_embed_dim
|
||||
|
||||
text_cond = self.get_text_cond(text)
|
||||
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||
text_embeds = text_cond['text_embed']
|
||||
|
||||
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
|
||||
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
|
||||
|
||||
text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
|
||||
top_sim_indices = text_image_sims.topk(k = 1).indices
|
||||
|
||||
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
|
||||
|
||||
top_image_embeds = image_embeds.gather(1, top_sim_indices)
|
||||
return rearrange(top_image_embeds, 'b 1 d -> b d')
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
@@ -643,6 +631,32 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def sample(self, text, num_samples_per_batch = 2):
|
||||
# in the paper, what they did was
|
||||
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
|
||||
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
|
||||
|
||||
batch_size = text.shape[0]
|
||||
image_embed_dim = self.image_embed_dim
|
||||
|
||||
text_cond = self.get_text_cond(text)
|
||||
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||
text_embeds = text_cond['text_embed']
|
||||
|
||||
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
|
||||
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
|
||||
|
||||
text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
|
||||
top_sim_indices = text_image_sims.topk(k = 1).indices
|
||||
|
||||
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
|
||||
|
||||
top_image_embeds = image_embeds.gather(1, top_sim_indices)
|
||||
return rearrange(top_image_embeds, 'b 1 d -> b d')
|
||||
|
||||
def forward(self, text, image, *args, **kwargs):
|
||||
b, device, img_size, = image.shape[0], image.device, self.image_size
|
||||
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
|
||||
@@ -797,15 +811,23 @@ class Unet(nn.Module):
|
||||
channels = 3,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
lowres_cond_upsample_mode = 'bilinear',
|
||||
blur_sigma = 0.1
|
||||
blur_sigma = 0.1,
|
||||
blur_kernel_size = 3,
|
||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||
):
|
||||
super().__init__()
|
||||
# save locals to take care of some hyperparameters for cascading DDPM
|
||||
|
||||
self._locals = locals()
|
||||
del self._locals['self']
|
||||
del self._locals['__class__']
|
||||
|
||||
# for eventual cascading diffusion
|
||||
|
||||
self.lowres_cond = lowres_cond
|
||||
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
|
||||
self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
|
||||
self.lowres_blur_kernel_size = blur_kernel_size
|
||||
self.lowres_blur_sigma = blur_sigma
|
||||
|
||||
# determine dimensions
|
||||
|
||||
@@ -847,27 +869,30 @@ class Unet(nn.Module):
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_first = ind == 0
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
|
||||
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
|
||||
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
|
||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
is_last = ind >= (num_resolutions - 2)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_out * 2, dim_in, cond_dim = cond_dim),
|
||||
ConvNextBlock(dim_in, dim_in, cond_dim = cond_dim),
|
||||
Upsample(dim_in) if not is_last else nn.Identity()
|
||||
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
|
||||
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
|
||||
out_dim = default(out_dim, channels)
|
||||
@@ -876,6 +901,15 @@ class Unet(nn.Module):
|
||||
nn.Conv2d(dim, out_dim, 1)
|
||||
)
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
# for cascading DDPM, then reinit the unet with the right settings
|
||||
def force_lowres_cond(self, lowres_cond):
|
||||
if lowres_cond == self.lowres_cond:
|
||||
return self
|
||||
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
|
||||
return self.__class__(**updated_kwargs)
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
*args,
|
||||
@@ -898,20 +932,24 @@ class Unet(nn.Module):
|
||||
image_embed,
|
||||
lowres_cond_img = None,
|
||||
text_encodings = None,
|
||||
cond_drop_prob = 0.
|
||||
cond_drop_prob = 0.,
|
||||
blur_sigma = None,
|
||||
blur_kernel_size = None
|
||||
):
|
||||
batch_size, device = x.shape[0], x.device
|
||||
|
||||
# add low resolution conditioning, if present
|
||||
|
||||
assert not self.lowres_cond and not exists(lowres_cond_img), 'low resolution conditioning image must be present'
|
||||
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
|
||||
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
|
||||
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
|
||||
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = self.lowres_cond_upsample_mode)
|
||||
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
|
||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||
|
||||
# time conditioning
|
||||
@@ -964,7 +1002,10 @@ class Unet(nn.Module):
|
||||
x = downsample(x)
|
||||
|
||||
x = self.mid_block1(x, mid_c)
|
||||
x = self.mid_attn(x)
|
||||
|
||||
if exists(self.mid_attn):
|
||||
x = self.mid_attn(x)
|
||||
|
||||
x = self.mid_block2(x, mid_c)
|
||||
|
||||
for convnext, convnext2, upsample in self.ups:
|
||||
@@ -978,22 +1019,42 @@ class Unet(nn.Module):
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
net,
|
||||
unet,
|
||||
*,
|
||||
clip,
|
||||
timesteps=1000,
|
||||
cond_drop_prob=0.2,
|
||||
loss_type="l1",
|
||||
beta_schedule="cosine",
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = 'l1',
|
||||
beta_schedule = 'cosine',
|
||||
image_sizes = None # for cascading ddpm, image size at each stage
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
freeze_model_and_make_eval_(clip)
|
||||
self.clip = clip
|
||||
|
||||
self.net = net
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
self.image_size = clip.image_size
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
self.unets = nn.ModuleList([])
|
||||
for ind, one_unet in enumerate(cast_tuple(unet)):
|
||||
is_first = ind == 0
|
||||
one_unet = one_unet.force_lowres_cond(not is_first)
|
||||
self.unets.append(one_unet)
|
||||
|
||||
# unet image sizes
|
||||
|
||||
image_sizes = default(image_sizes, (clip.image_size,))
|
||||
image_sizes = tuple(sorted(set(image_sizes)))
|
||||
|
||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||
self.image_sizes = image_sizes
|
||||
|
||||
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
|
||||
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
||||
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
@@ -1048,6 +1109,7 @@ class Decoder(nn.Module):
|
||||
return text_encodings[:, 1:]
|
||||
|
||||
def get_image_embed(self, image):
|
||||
image = resize_image_to(image, self.clip_image_size)
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
@@ -1074,8 +1136,9 @@ class Decoder(nn.Module):
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
|
||||
pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
|
||||
|
||||
if clip_denoised:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
@@ -1084,33 +1147,25 @@ class Decoder(nn.Module):
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
|
||||
def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device=device)
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
return img
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||
batch_size = image_embed.shape[0]
|
||||
image_size = self.image_size
|
||||
channels = self.channels
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
@@ -1119,16 +1174,17 @@ class Decoder(nn.Module):
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
|
||||
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
|
||||
|
||||
x_recon = self.net(
|
||||
x_recon = unet(
|
||||
x_noisy,
|
||||
t,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
cond_drop_prob = self.cond_drop_prob
|
||||
)
|
||||
|
||||
@@ -1143,17 +1199,45 @@ class Decoder(nn.Module):
|
||||
|
||||
return loss
|
||||
|
||||
def forward(self, image, text = None):
|
||||
b, device, img_size, = image.shape[0], image.device, self.image_size
|
||||
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||
batch_size = image_embed.shape[0]
|
||||
channels = self.channels
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
|
||||
img = None
|
||||
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
|
||||
shape = (batch_size, channels, image_size, image_size)
|
||||
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
|
||||
|
||||
return img
|
||||
|
||||
def forward(self, image, text = None, image_embed = None, unet_number = None):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
assert 1 <= unet_number <= len(self.unets)
|
||||
|
||||
index = unet_number - 1
|
||||
unet = self.unets[index]
|
||||
target_image_size = self.image_sizes[index]
|
||||
|
||||
b, c, h, w, device, = *image.shape, image.device
|
||||
|
||||
check_shape(image, 'b c h w', c = self.channels)
|
||||
assert h >= target_image_size and w >= target_image_size
|
||||
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
|
||||
image_embed = self.get_image_embed(image)
|
||||
if not exists(image_embed):
|
||||
image_embed = self.get_image_embed(image)
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
|
||||
loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
|
||||
return loss
|
||||
lowres_cond_img = image if index > 0 else None
|
||||
ddpm_image = resize_image_to(image, target_image_size)
|
||||
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
# main class
|
||||
|
||||
|
||||
Reference in New Issue
Block a user