Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-12 11:34:29 +01:00
Compare commits
26 Commits
| SHA1 |
|---|
| 1c1e508369 |
| f19c99ecb0 |
| 721a444686 |
| 63450b466d |
| 20e7eb5a9b |
| e2f9615afa |
| 0d1c07c803 |
| a389f81138 |
| 0283556608 |
| 5063d192b6 |
| f4a54e475e |
| fb662a62f3 |
| 587c8c9b44 |
| aa900213e7 |
| cb26187450 |
| 625ce23f6b |
| dbf4a281f1 |
| 4ab527e779 |
| d0cdeb3247 |
| 8c610aad9a |
| 6700381a37 |
| 20377f889a |
| 6edb1c5dd0 |
| b093f92182 |
| fa3bb6ba5c |
| 2705e7c9b0 |
README.md (127 lines changed)
@@ -47,7 +47,7 @@ clip = CLIP(
    use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
    extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_visual_ssl = True, # whether to do self supervised learning on iages
    use_visual_ssl = True, # whether to do self supervised learning on images
    visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
    use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)
    text_ssl_loss_weight = 0.05, # weight for text MLM loss
@@ -110,7 +110,8 @@ decoder = Decoder(
    unet = unet,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)
@@ -229,7 +230,8 @@ decoder = Decoder(
    unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
    image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
    timesteps = 1000,
    cond_drop_prob = 0.2
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)
@@ -348,7 +350,8 @@ decoder = Decoder(
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()

@@ -430,8 +433,8 @@ images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone

clip_image_embeds = diffusion_prior.get_image_embed(images)
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed

# feed text and images into diffusion prior network
@@ -495,14 +498,105 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings
```

## OpenAI CLIP

Although there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.

To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so

```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter

# openai pretrained clip - defaults to ViT/B-32

clip = OpenAIClipAdapter()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# prior networks (with transformer)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

loss = diffusion_prior(text, images)
loss.backward()

# do above for many steps ...

# decoder (with unet)

unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()

for unet_number in (1, 2):
    loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
    loss.backward()

# do above for many steps

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(
    ['a butterfly trying to escape a tornado'],
    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)

# save your image (in this example, of size 256x256)
```

Now you'll just have to worry about training the Prior and the Decoder!
## Experimental
### DALL-E2 with Latent Diffusion

This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
This repository decides to take the next step and offer DALL-E v2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.

You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.

The repository also comes equipped with all the necessary settings to recreate `ViT-VQGan` from the <a href="https://arxiv.org/abs/2110.04627">Improved VQGans</a> paper. Furthermore, the <a href="https://github.com/lucidrains/vector-quantize-pytorch">vector quantization</a> library also comes equipped to do <a href="https://arxiv.org/abs/2203.01941">residual or multi-headed quantization</a>, which I believe will give an even further boost in performance to the autoencoder.

```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
@@ -526,7 +620,7 @@ clip = CLIP(
# 3 unets for the decoder (a la cascading DDPM)

# first two unets are doing latent diffusion
# vqgan-vae must be trained before hand
# vqgan-vae must be trained beforehand

vae1 = VQGanVAE(
    dim = 32,
@@ -579,7 +673,8 @@ decoder = Decoder(
    unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
    image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
    timesteps = 100,
    cond_drop_prob = 0.2
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)
@@ -645,11 +740,13 @@ Once built, images will be saved to the same directory the command is invoked
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae
@@ -697,16 +794,6 @@ Once built, images will be saved to the same directory the command is invoked
}
```

```bibtex
@article{Arar2021LearnedQF,
title = {Learned Queries for Efficient Local Attention},
author = {Moab Arar and Ariel Shamir and Amit H. Bermano},
journal = {ArXiv},
year = {2021},
volume = {abs/2112.11435}
}
```

```bibtex
@article{Yu2021VectorquantizedIM,
title = {Vector-quantized Image Modeling with Improved VQGAN},
@@ -1,4 +1,5 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter

from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP
@@ -1,130 +0,0 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat

class LayerNormChan(nn.Module):
def __init__(
self,
dim,
eps = 1e-5
):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))

def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.gamma

# attention-based upsampling
# from https://arxiv.org/abs/2112.11435

class QueryAndAttend(nn.Module):
def __init__(
self,
*,
dim,
num_queries = 1,
dim_head = 32,
heads = 8,
window_size = 3
):
super().__init__()
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.heads = heads
self.dim_head = dim_head
self.window_size = window_size
self.num_queries = num_queries

self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))

self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)

self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim * 2, 1, bias = False),
nn.Tanh(),
nn.Conv2d(dim * 2, dim, 1, bias = False)
)

def forward(self, x):
"""
einstein notation
b - batch
h - heads
l - num queries
d - head dimension
x - height
y - width
j - source sequence for attending to (kernel size squared in this case)
"""
wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
batch, _, height, width = x.shape

is_one_query = self.num_queries == 1

# queries, keys, values

q = self.queries * self.scale
k, v = self.to_kv(x).chunk(2, dim = 1)

# similarities

sim = einsum('h l d, b d x y -> b h l x y', q, k)
sim = rearrange(sim, 'b ... x y -> b (...) x y')

# unfold the similarity scores, with float(-inf) as padding value

mask_value = -torch.finfo(sim.dtype).max
sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
sim = F.unfold(sim, kernel_size = wsz)
sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)

# rel pos bias

sim = sim + self.rel_pos_bias

# numerically stable attention

sim = sim - sim.amax(dim = -3, keepdim = True).detach()
attn = sim.softmax(dim = -3)

# unfold values

v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
v = F.unfold(v, kernel_size = wsz)
v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)

# aggregate values

out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)

# combine heads

out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
out = self.to_out(out)
out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)

# return original input if one query

if is_one_query:
out = rearrange(out, 'b 1 ... -> b ...')

return out

class QueryAttnUpsample(nn.Module):
def __init__(self, dim, **kwargs):
super().__init__()
self.norm = LayerNormChan(dim)
self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)

def forward(self, x):
x = self.norm(x)
out = self.qna(x)
out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
return out
@@ -3,10 +3,12 @@ from tqdm import tqdm
from inspect import isfunction
from functools import partial
from contextlib import contextmanager
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import nn, einsum
import torchvision.transforms as T

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
@@ -17,7 +19,6 @@ from kornia.filters import gaussian_blur2d

from dalle2_pytorch.tokenizer import tokenizer
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
from dalle2_pytorch.attention import QueryAttnUpsample

# use x-clip

@@ -36,6 +37,10 @@ def default(val, d):
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)

@contextmanager
def null_context(*args, **kwargs):
yield

def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
@@ -86,6 +91,144 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
return F.interpolate(t, size = shape, mode = mode, align_corners = False)

# image normalization functions
# ddpms expect images to be in the range of -1 to 1
# but CLIP may otherwise

def normalize_img(img):
return img * 2 - 1

def unnormalize_img(normed_img):
return (normed_img + 1) * 0.5

# clip related adapters

EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])

class BaseClipAdapter(nn.Module):
def __init__(self, clip):
super().__init__()
self.clip = clip

@property
def dim_latent(self):
raise NotImplementedError

@property
def image_size(self):
raise NotImplementedError

@property
def image_channels(self):
raise NotImplementedError

@property
def max_text_len(self):
raise NotImplementedError

def embed_text(self, text):
raise NotImplementedError

def embed_image(self, image):
raise NotImplementedError

class XClipAdapter(BaseClipAdapter):
@property
def dim_latent(self):
return self.clip.dim_latent

@property
def image_size(self):
return self.clip.image_size

@property
def image_channels(self):
return self.clip.image_channels

@property
def max_text_len(self):
return self.clip.text_seq_len

@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)

@torch.no_grad()
def embed_image(self, image):
image = resize_image_to(image, self.image_size)
encoder_output = self.clip.visual_transformer(image)
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
image_embed = self.clip.to_visual_latent(image_cls)
return EmbeddedImage(l2norm(image_embed), image_encodings)
class OpenAIClipAdapter(BaseClipAdapter):
def __init__(
self,
name = 'ViT-B/32'
):
import clip
openai_clip, preprocess = clip.load(name)
super().__init__(openai_clip)

text_attention_final = self.find_layer('ln_final')
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False

def find_layer(self, layer):
modules = dict([*self.clip.named_modules()])
return modules.get(layer, None)

def clear(self):
if self.cleared:
return

self.handle()

def _hook(self, _, inputs, outputs):
self.text_encodings = outputs

@property
def dim_latent(self):
return 512

@property
def image_size(self):
return self.clip.visual.input_resolution

@property
def image_channels(self):
return 3

@property
def max_text_len(self):
return self.clip.context_length

@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
assert not self.cleared

text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
del self.text_encodings
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)

@torch.no_grad()
def embed_image(self, image):
assert not self.cleared
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image_embed = self.clip.encode_image(image)
return EmbeddedImage(image_embed.float(), None)

# classifier free guidance functions

def prob_mask_like(shape, prob, device):
@@ -166,7 +309,18 @@ class BaseGaussianDiffusion(nn.Module):
timesteps, = betas.shape
self.num_timesteps = int(timesteps)

if loss_type == 'l1':
loss_fn = F.l1_loss
elif loss_type == 'l2':
loss_fn = F.mse_loss
elif loss_type == 'huber':
loss_fn = F.smooth_l1_loss
else:
raise NotImplementedError()

self.loss_type = loss_type
self.loss_fn = loss_fn

self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
@@ -530,14 +684,14 @@ class DiffusionPriorNetwork(nn.Module):

# classifier free guidance

cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
keep_mask = rearrange(keep_mask, 'b -> b 1')

mask &= cond_prob_mask
mask &= keep_mask

# whether text embedding is masked or not depends on the classifier free guidance conditional masking

mask = torch.cat((mask, cond_prob_mask), dim = 1)
mask = torch.cat((mask, keep_mask), dim = 1)

# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
@@ -590,7 +744,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
)

if exists(clip):
assert isinstance(clip, CLIP)
if isinstance(clip, CLIP):
clip = XClipAdapter(clip)

assert isinstance(clip, BaseClipAdapter)
freeze_model_and_make_eval_(clip)
self.clip = clip
else:
@@ -607,29 +764,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.predict_x_start = predict_x_start
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

@torch.no_grad()
def get_image_embed(self, image):
assert exists(self.clip)

image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)

@torch.no_grad()
def get_text_cond(self, text):
assert exists(self.clip)

text_encodings = self.clip.text_transformer(text)
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
text_embed = l2norm(text_embed)

if not self.condition_on_text_encodings:
return dict(text_embed = text_embed)

return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)

def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond)
@@ -666,29 +800,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img

def p_losses(self, image_embed, t, text_cond, noise = None):
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))

image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)
image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)

x_recon = self.net(
pred = self.net(
image_embed_noisy,
t,
times,
cond_drop_prob = self.cond_drop_prob,
**text_cond
)

to_predict = noise if not self.predict_x_start else image_embed

if self.loss_type == 'l1':
loss = F.l1_loss(to_predict, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(to_predict, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(to_predict, x_recon)
else:
raise NotImplementedError()
target = noise if not self.predict_x_start else image_embed

loss = self.loss_fn(pred, target)
return loss

@torch.no_grad()
@@ -701,7 +827,12 @@ class DiffusionPrior(BaseGaussianDiffusion):
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim

text_cond = self.get_text_cond(text)
text_embed, text_encodings, text_mask = self.clip.embed_text(text)

text_cond = dict(text_embed = text_embed)

if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}

image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']
@@ -733,18 +864,18 @@ class DiffusionPrior(BaseGaussianDiffusion):
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'

if exists(image):
image_embed = self.get_image_embed(image)
image_embed, _ = self.clip.embed_image(image)

# calculate text conditionings, based on what is passed in

if exists(text):
text_cond = self.get_text_cond(text)
else:
text_cond = dict(
text_embed = text_embed,
text_encodings = text_encodings,
mask = text_mask
)
text_embed, text_encodings, text_mask = self.clip.embed_text(text)

text_cond = dict(text_embed = text_embed)

if self.condition_on_text_encodings:
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}

# timestep conditioning from ddpm

@@ -753,8 +884,7 @@ class DiffusionPrior(BaseGaussianDiffusion):

# calculate forward loss

loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
return loss
return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)

# decoder
@@ -971,6 +1101,8 @@ class Unet(nn.Module):
# for classifier free guidance

self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))

self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))

# attention related params
@@ -1024,13 +1156,14 @@ class Unet(nn.Module):
self,
*,
lowres_cond,
channels
channels,
cond_on_image_embeds
):
if lowres_cond == self.lowres_cond and channels == self.channels:
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
return self

updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
return self.__class__(**updated_kwargs)
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
return self.__class__(**{**self._locals, **updated_kwargs})

def forward_with_cond_scale(
self,
@@ -1043,7 +1176,7 @@ class Unet(nn.Module):
if cond_scale == 1:
return logits

null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale

def forward(
@@ -1054,7 +1187,9 @@ class Unet(nn.Module):
image_embed,
lowres_cond_img = None,
text_encodings = None,
cond_drop_prob = 0.,
text_mask = None,
image_cond_drop_prob = 0.,
text_cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
):
@@ -1073,8 +1208,10 @@ class Unet(nn.Module):
# conditional dropout

cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)

image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')

# mask out image embedding depending on condition dropout
# for classifier free guidance
@@ -1085,7 +1222,7 @@ class Unet(nn.Module):
image_tokens = self.image_to_cond(image_embed)

image_tokens = torch.where(
cond_prob_mask,
image_keep_mask,
image_tokens,
self.null_image_embed
)
@@ -1095,11 +1232,25 @@ class Unet(nn.Module):
text_tokens = None

if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = text_tokens[:, :self.max_text_len]

text_tokens_len = text_tokens.shape[1]
remainder = self.max_text_len - text_tokens_len

if remainder > 0:
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

if exists(text_mask):
if remainder > 0:
text_mask = F.pad(text_mask, (0, remainder), value = False)

text_keep_mask &= text_mask

text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
text_keep_mask,
text_tokens,
self.null_text_embed[:, :text_tokens.shape[1]]
self.null_text_embed
)

# main conditioning tokens (c)
@@ -1167,7 +1318,7 @@ class LowresConditioner(nn.Module):
target_image_size = cast_tuple(target_image_size, 2)

if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)

if self.training:
# when training, blur the low resolution conditional image
@@ -1187,7 +1338,8 @@ class Decoder(BaseGaussianDiffusion):
clip,
vae = tuple(),
timesteps = 1000,
cond_drop_prob = 0.2,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l1',
beta_schedule = 'cosine',
predict_x_start = False,
@@ -1205,8 +1357,12 @@ class Decoder(BaseGaussianDiffusion):
loss_type = loss_type
)

assert isinstance(clip, CLIP)
if isinstance(clip, CLIP):
clip = XClipAdapter(clip)

freeze_model_and_make_eval_(clip)
assert isinstance(clip, BaseClipAdapter)

self.clip = clip
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
@@ -1233,6 +1389,7 @@ class Decoder(BaseGaussianDiffusion):

one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
cond_on_image_embeds = is_first,
channels = unet_channels
)
@@ -1266,7 +1423,8 @@ class Decoder(BaseGaussianDiffusion):
# classifier free guidance

self.cond_drop_prob = cond_drop_prob
self.image_cond_drop_prob = image_cond_drop_prob
self.text_cond_drop_prob = text_cond_drop_prob

def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
@@ -1287,21 +1445,14 @@ class Decoder(BaseGaussianDiffusion):
yield
unet.cpu()

@torch.no_grad()
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]

@torch.no_grad()
def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size)
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
image_embed, _ = self.clip.embed_image(image)
return image_embed

def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)

if predict_x_start:
x_recon = pred
@@ -1315,16 +1466,16 @@ class Decoder(BaseGaussianDiffusion):
return model_mean, posterior_variance, posterior_log_variance

@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
device = self.betas.device

b = shape[0]
@@ -1337,6 +1488,7 @@ class Decoder(BaseGaussianDiffusion):
torch.full((b,), i, device = device, dtype = torch.long),
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start
@@ -1344,31 +1496,25 @@ class Decoder(BaseGaussianDiffusion):
return img

def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))

x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)

x_recon = unet(
pred = unet(
x_noisy,
t,
times,
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
lowres_cond_img = lowres_cond_img,
cond_drop_prob = self.cond_drop_prob
image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
)

target = noise if not predict_x_start else x_start

if self.loss_type == 'l1':
loss = F.l1_loss(target, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(target, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(target, x_recon)
else:
raise NotImplementedError()

loss = self.loss_fn(pred, target)
return loss

@torch.no_grad()
@@ -1376,14 +1522,20 @@ class Decoder(BaseGaussianDiffusion):
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]

text_encodings = self.get_text_encodings(text) if exists(text) else None
text_encodings = text_mask = None
if exists(text):
_, text_encodings, text_mask = self.clip.embed_text(text)

assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'

img = None

for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
with self.one_unet_in_gpu(unet = unet):

context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()

with context:
lowres_cond_img = None
shape = (batch_size, channel, image_size, image_size)

@@ -1401,6 +1553,7 @@ class Decoder(BaseGaussianDiffusion):
shape,
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale,
predict_x_start = predict_x_start,
lowres_cond_img = lowres_cond_img
@@ -1436,11 +1589,14 @@ class Decoder(BaseGaussianDiffusion):
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)

if not exists(image_embed):
image_embed = self.get_image_embed(image)
image_embed, _ = self.clip.embed_image(image)

text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
text_encodings = text_mask = None
if exists(text) and not exists(text_encodings):
_, text_encodings, text_mask = self.clip.embed_text(text)

assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'

lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)
@@ -1452,7 +1608,7 @@ class Decoder(BaseGaussianDiffusion):
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)

return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)

# main class
@@ -1473,12 +1629,15 @@ class DALLE2(nn.Module):
self.prior_num_samples = prior_num_samples
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings

self.to_pil = T.ToPILImage()

@torch.no_grad()
@eval_decorator
def forward(
self,
text,
cond_scale = 1.
cond_scale = 1.,
return_pil_images = False
):
device = next(self.parameters()).device
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
@@ -1492,6 +1651,9 @@ class DALLE2(nn.Module):
text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)

if return_pil_images:
images = list(map(self.to_pil, images.unbind(dim = 0)))

if one_text:
return images[0]
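The hunk above adds a `return_pil_images` flag to `DALLE2.forward`, converting the sampled tensors to PIL images when set. A minimal sketch of a call using it, reusing the prompt from the README example (the save path below is only illustrative and not part of the diff):

```python
# hypothetical usage of the new return_pil_images flag on DALLE2.forward
images = dalle2(
    ['a butterfly trying to escape a tornado'],
    cond_scale = 2.,
    return_pil_images = True  # returns a list of PIL images instead of a tensor
)

images[0].save('./butterfly.png')  # illustrative save path
```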
dalle2_pytorch/optimizer.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from torch.optim import AdamW, Adam

def separate_weight_decayable_params(params):
no_wd_params = set([param for param in params if param.ndim < 2])
wd_params = set(params) - no_wd_params
return wd_params, no_wd_params

def get_optimizer(
params,
lr = 3e-4,
wd = 1e-2,
betas = (0.9, 0.999),
filter_by_requires_grad = False
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))

if wd == 0:
return Adam(params, lr = lr, betas = betas)

params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)

param_groups = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]

return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
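As a minimal usage sketch of the new `get_optimizer` helper (the `unet1`, `decoder`, and `images` names are assumed to come from the README training example above; this snippet is not part of the diff):

```python
from dalle2_pytorch.optimizer import get_optimizer

# one optimizer per unet in the cascade, with weight-decayable params split out automatically
opt = get_optimizer(unet1.parameters(), lr = 3e-4, wd = 1e-2)

loss = decoder(images, unet_number = 1)
loss.backward()

opt.step()
opt.zero_grad()
```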
@@ -35,7 +35,7 @@ class EMA(nn.Module):
self.update_moving_average(self.ema_model, self.online_model)

def update_moving_average(ma_model, current_model):
def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
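The hunk above fixes `update_moving_average`, which was previously missing its `self` parameter. The body of `calculate_ema` is cut off by the diff context; a minimal sketch of the standard exponential-moving-average rule it presumably applies (this completion is an assumption, not shown in the diff):

```python
# assumed completion of calculate_ema (standard EMA update rule)
def calculate_ema(beta, old, new):
    if old is None:  # the repository checks exists(old)
        return new
    return old * beta + (1 - beta) * new
```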
@@ -15,8 +15,6 @@ from einops import rearrange, reduce, repeat
from einops_exts import rearrange_many
from einops.layers.torch import Rearrange

from dalle2_pytorch.attention import QueryAttnUpsample

# constants

MList = nn.ModuleList
@@ -547,6 +545,7 @@ class VQGanVAE(nn.Module):
l2_recon_loss = False,
use_hinge_loss = True,
vgg = None,
vq_codebook_dim = 256,
vq_codebook_size = 512,
vq_decay = 0.8,
vq_commitment_weight = 1.,
@@ -581,6 +580,7 @@ class VQGanVAE(nn.Module):

self.vq = VQ(
dim = self.enc_dec.encoded_dim,
codebook_dim = vq_codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,
setup.py (5 lines changed)
@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.54',
version = '0.0.75',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -23,6 +23,7 @@ setup(
],
install_requires=[
'click',
'clip-anytorch',
'einops>=0.4',
'einops-exts>=0.0.3',
'kornia>=0.5.4',
@@ -31,7 +32,7 @@ setup(
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.4.4',
'x-clip>=0.5.1',
'youtokentome'
],
classifiers=[