mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aec5575d09 | ||
|
|
9773f10d6c | ||
|
|
a6bf8ddef6 | ||
|
|
86e692d24f | ||
|
|
97b751209f | ||
|
|
74103fd8d6 |
@@ -10,7 +10,7 @@ The main novelty seems to be an extra layer of indirection with the prior networ
|
||||
|
||||
This model is SOTA for text-to-image for now.
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
|
||||
|
||||
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
|
||||
@@ -822,6 +822,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] bring in tools to train vqgan-vae
|
||||
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
|
||||
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
|
||||
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo)
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
@@ -833,9 +834,9 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [ ] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] make sure resnet | convnext block hyperparameters can be configurable across unet depth (groups and expansion factor)
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
@@ -16,10 +16,13 @@ from einops_exts import rearrange_many, repeat_many, check_shape
|
||||
from einops_exts.torch import EinopsToAndFrom
|
||||
|
||||
from kornia.filters import gaussian_blur2d
|
||||
import kornia.augmentation as K
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||
|
||||
from resize_right import resize
|
||||
|
||||
# use x-clip
|
||||
|
||||
from x_clip import CLIP
|
||||
@@ -85,14 +88,14 @@ def freeze_model_and_make_eval_(model):
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1)
|
||||
|
||||
def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://github.com/assafshocher/ResizeRight
|
||||
shape = cast_tuple(image_size, 2)
|
||||
orig_image_size = t.shape[-2:]
|
||||
def resize_image_to(image, target_image_size):
|
||||
orig_image_size = image.shape[-1]
|
||||
|
||||
if orig_image_size == shape:
|
||||
return t
|
||||
if orig_image_size == target_image_size:
|
||||
return image
|
||||
|
||||
return F.interpolate(t, size = shape, mode = mode, align_corners = False)
|
||||
scale_factors = target_image_size / orig_image_size
|
||||
return resize(image, scale_factors = scale_factors)
|
||||
|
||||
# image normalization functions
|
||||
# ddpms expect images to be in the range of -1 to 1
|
||||
@@ -804,7 +807,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
|
||||
@@ -813,7 +816,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def p_sample_loop(self, shape, text_cond):
|
||||
device = self.betas.device
|
||||
|
||||
@@ -841,7 +844,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
loss = self.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
@eval_decorator
|
||||
def sample(self, text, num_samples_per_batch = 2):
|
||||
# in the paper, what they did was
|
||||
@@ -1476,13 +1479,11 @@ class Unet(nn.Module):
|
||||
class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cond_upsample_mode = 'bilinear',
|
||||
downsample_first = True,
|
||||
blur_sigma = 0.1,
|
||||
blur_kernel_size = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.cond_upsample_mode = cond_upsample_mode
|
||||
self.downsample_first = downsample_first
|
||||
self.blur_sigma = blur_sigma
|
||||
self.blur_kernel_size = blur_kernel_size
|
||||
@@ -1496,10 +1497,8 @@ class LowresConditioner(nn.Module):
|
||||
blur_sigma = None,
|
||||
blur_kernel_size = None
|
||||
):
|
||||
target_image_size = cast_tuple(target_image_size, 2)
|
||||
|
||||
if self.training and self.downsample_first and exists(downsample_image_size):
|
||||
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
|
||||
cond_fmap = resize_image_to(cond_fmap, downsample_image_size)
|
||||
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
@@ -1507,7 +1506,7 @@ class LowresConditioner(nn.Module):
|
||||
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size)
|
||||
|
||||
return cond_fmap
|
||||
|
||||
@@ -1526,7 +1525,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
predict_x_start = False,
|
||||
predict_x_start_for_latent_diffusion = False,
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
@@ -1588,6 +1587,10 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.image_sizes = image_sizes
|
||||
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
|
||||
|
||||
# random crop sizes (for super-resoluting unets at the end of cascade?)
|
||||
|
||||
self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))
|
||||
|
||||
# predict x0 config
|
||||
|
||||
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
||||
@@ -1598,7 +1601,6 @@ class Decoder(BaseGaussianDiffusion):
|
||||
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
||||
|
||||
self.to_lowres_cond = LowresConditioner(
|
||||
cond_upsample_mode = lowres_cond_upsample_mode,
|
||||
downsample_first = lowres_downsample_first,
|
||||
blur_sigma = blur_sigma,
|
||||
blur_kernel_size = blur_kernel_size,
|
||||
@@ -1633,12 +1635,6 @@ class Decoder(BaseGaussianDiffusion):
|
||||
yield
|
||||
unet.cpu()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
return image_embed
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
@@ -1653,7 +1649,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||
@@ -1662,7 +1658,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
@@ -1706,7 +1702,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
loss = self.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
@eval_decorator
|
||||
def sample(
|
||||
self,
|
||||
@@ -1777,10 +1773,10 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
unet = self.get_unet(unet_number)
|
||||
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
vae = self.vaes[unet_index]
|
||||
predict_x_start = self.predict_x_start[unet_index]
|
||||
|
||||
vae = self.vaes[unet_index]
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
predict_x_start = self.predict_x_start[unet_index]
|
||||
random_crop_size = self.random_crop_sizes[unet_index]
|
||||
b, c, h, w, device, = *image.shape, image.device
|
||||
|
||||
check_shape(image, 'b c h w', c = self.channels)
|
||||
@@ -1801,6 +1797,14 @@ class Decoder(BaseGaussianDiffusion):
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
image = resize_image_to(image, target_image_size)
|
||||
|
||||
if exists(random_crop_size):
|
||||
aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)
|
||||
|
||||
# make sure low res conditioner and image both get augmented the same way
|
||||
# detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
|
||||
image = aug(image)
|
||||
lowres_cond_img = aug(lowres_cond_img, params = aug._params)
|
||||
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
image = vae.encode(image)
|
||||
@@ -1831,7 +1835,7 @@ class DALLE2(nn.Module):
|
||||
|
||||
self.to_pil = T.ToPILImage()
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
@eval_decorator
|
||||
def forward(
|
||||
self,
|
||||
|
||||
3
setup.py
3
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.0.96',
|
||||
version = '0.0.100',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -29,6 +29,7 @@ setup(
|
||||
'embedding-reader',
|
||||
'kornia>=0.5.4',
|
||||
'pillow',
|
||||
'resize-right>=0.0.2',
|
||||
'torch>=1.10',
|
||||
'torchvision',
|
||||
'tqdm',
|
||||
|
||||
Reference in New Issue
Block a user