mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
846162ef3e | ||
|
|
39d3659ad9 | ||
|
|
f4a54e475e | ||
|
|
fb662a62f3 | ||
|
|
587c8c9b44 | ||
|
|
aa900213e7 | ||
|
|
cb26187450 | ||
|
|
625ce23f6b | ||
|
|
dbf4a281f1 | ||
|
|
4ab527e779 | ||
|
|
d0cdeb3247 | ||
|
|
8c610aad9a | ||
|
|
6700381a37 | ||
|
|
20377f889a |
93
README.md
93
README.md
@@ -495,6 +495,96 @@ loss.backward()
|
||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||
```
|
||||
|
||||
## OpenAI CLIP
|
||||
|
||||
Although there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.
|
||||
|
||||
First you'll need to install <a href="https://github.com/openai/CLIP#usage">the prerequisites</a>
|
||||
|
||||
Then to use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
|
||||
|
||||
# openai pretrained clip - defaults to ViT/B-32
|
||||
|
||||
clip = OpenAIClipAdapter()
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
|
||||
# prior networks (with transformer)
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
).cuda()
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
loss = diffusion_prior(text, images)
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps ...
|
||||
|
||||
# decoder (with unet)
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 16,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2,
|
||||
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps
|
||||
|
||||
dalle2 = DALLE2(
|
||||
prior = diffusion_prior,
|
||||
decoder = decoder
|
||||
)
|
||||
|
||||
images = dalle2(
|
||||
['a butterfly trying to escape a tornado'],
|
||||
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
|
||||
)
|
||||
|
||||
# save your image (in this example, of size 256x256)
|
||||
```
|
||||
|
||||
Now you'll just have to worry about training the Prior and the Decoder!
|
||||
|
||||
## Experimental
|
||||
|
||||
### DALL-E2 with Latent Diffusion
|
||||
@@ -647,11 +737,12 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
|
||||
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
||||
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||
- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
||||
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in tools to train vqgan-vae
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from x_clip import CLIP
|
||||
|
||||
@@ -7,6 +7,7 @@ from contextlib import contextmanager
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
import torchvision.transforms as T
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
@@ -89,6 +90,145 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
|
||||
|
||||
return F.interpolate(t, size = shape, mode = mode, align_corners = False)
|
||||
|
||||
# image normalization functions
|
||||
# ddpms expect images to be in the range of -1 to 1
|
||||
# but CLIP may otherwise
|
||||
|
||||
def normalize_img(img):
|
||||
return img * 2 - 1
|
||||
|
||||
def unnormalize_img(normed_img):
|
||||
return (normed_img + 1) * 0.5
|
||||
|
||||
# clip related adapters
|
||||
|
||||
class BaseClipAdapter(nn.Module):
|
||||
def __init__(self, clip):
|
||||
super().__init__()
|
||||
self.clip = clip
|
||||
|
||||
@property
|
||||
def dim_latent(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def image_size(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def image_channels(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def max_text_len(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def embed_text(self, text):
|
||||
raise NotImplementedError
|
||||
|
||||
def embed_image(self, image):
|
||||
raise NotImplementedError
|
||||
|
||||
class XClipAdapter(BaseClipAdapter):
|
||||
@property
|
||||
def dim_latent(self):
|
||||
return self.clip.dim_latent
|
||||
|
||||
@property
|
||||
def image_size(self):
|
||||
return self.clip.image_size
|
||||
|
||||
@property
|
||||
def image_channels(self):
|
||||
return self.clip.image_channels
|
||||
|
||||
@property
|
||||
def max_text_len(self):
|
||||
return self.clip.text_seq_len
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_text(self, text):
|
||||
text = text[..., :self.max_text_len]
|
||||
text_mask = text != 0
|
||||
encoder_output = self.clip.text_transformer(text)
|
||||
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
||||
text_embed = self.clip.to_text_latent(text_cls)
|
||||
return l2norm(text_embed), text_encodings, text_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
image = resize_image_to(image, self.image_size)
|
||||
encoder_output = self.clip.visual_transformer(image)
|
||||
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return l2norm(image_embed), image_encodings
|
||||
|
||||
class OpenAIClipAdapter(BaseClipAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
name = 'ViT-B/32'
|
||||
):
|
||||
try:
|
||||
import clip
|
||||
except ImportError:
|
||||
print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
|
||||
|
||||
openai_clip, _ = clip.load(name)
|
||||
super().__init__(openai_clip)
|
||||
|
||||
text_attention_final = self.find_layer('ln_final')
|
||||
self.handle = text_attention_final.register_forward_hook(self._hook)
|
||||
self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
self.cleared = False
|
||||
|
||||
def find_layer(self, layer):
|
||||
modules = dict([*self.clip.named_modules()])
|
||||
return modules.get(layer, None)
|
||||
|
||||
def clear(self):
|
||||
if self.cleared:
|
||||
return
|
||||
|
||||
self.handle()
|
||||
|
||||
def _hook(self, _, inputs, outputs):
|
||||
self.text_encodings = outputs
|
||||
|
||||
@property
|
||||
def dim_latent(self):
|
||||
return 512
|
||||
|
||||
@property
|
||||
def image_size(self):
|
||||
return self.clip.visual.input_resolution
|
||||
|
||||
@property
|
||||
def image_channels(self):
|
||||
return 3
|
||||
|
||||
@property
|
||||
def max_text_len(self):
|
||||
return self.clip.context_length
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_text(self, text):
|
||||
text = text[..., :self.max_text_len]
|
||||
text_mask = text != 0
|
||||
assert not self.cleared
|
||||
|
||||
text_embed = self.clip.encode_text(text)
|
||||
text_encodings = self.text_encodings
|
||||
del self.text_encodings
|
||||
return text_embed.float(), text_encodings.float(), text_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
assert not self.cleared
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.clip_normalize(unnormalize_img(image))
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return image_embed.float(), None
|
||||
|
||||
# classifier free guidance functions
|
||||
|
||||
def prob_mask_like(shape, prob, device):
|
||||
@@ -169,7 +309,18 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
|
||||
if loss_type == 'l1':
|
||||
loss_fn = F.l1_loss
|
||||
elif loss_type == 'l2':
|
||||
loss_fn = F.mse_loss
|
||||
elif loss_type == 'huber':
|
||||
loss_fn = F.smooth_l1_loss
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.loss_type = loss_type
|
||||
self.loss_fn = loss_fn
|
||||
|
||||
self.register_buffer('betas', betas)
|
||||
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
@@ -593,7 +744,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
)
|
||||
|
||||
if exists(clip):
|
||||
assert isinstance(clip, CLIP)
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip)
|
||||
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
freeze_model_and_make_eval_(clip)
|
||||
self.clip = clip
|
||||
else:
|
||||
@@ -610,29 +764,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
self.predict_x_start = predict_x_start
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
assert exists(self.clip)
|
||||
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return l2norm(image_embed)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_cond(self, text):
|
||||
assert exists(self.clip)
|
||||
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
|
||||
text_embed = self.clip.to_text_latent(text_cls)
|
||||
text_embed = l2norm(text_embed)
|
||||
|
||||
if not self.condition_on_text_encodings:
|
||||
return dict(text_embed = text_embed)
|
||||
|
||||
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
pred = self.net(x, t, **text_cond)
|
||||
|
||||
@@ -669,29 +800,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
|
||||
return img
|
||||
|
||||
def p_losses(self, image_embed, t, text_cond, noise = None):
|
||||
def p_losses(self, image_embed, times, text_cond, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(image_embed))
|
||||
|
||||
image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)
|
||||
image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
|
||||
|
||||
x_recon = self.net(
|
||||
pred = self.net(
|
||||
image_embed_noisy,
|
||||
t,
|
||||
times,
|
||||
cond_drop_prob = self.cond_drop_prob,
|
||||
**text_cond
|
||||
)
|
||||
|
||||
to_predict = noise if not self.predict_x_start else image_embed
|
||||
|
||||
if self.loss_type == 'l1':
|
||||
loss = F.l1_loss(to_predict, x_recon)
|
||||
elif self.loss_type == 'l2':
|
||||
loss = F.mse_loss(to_predict, x_recon)
|
||||
elif self.loss_type == "huber":
|
||||
loss = F.smooth_l1_loss(to_predict, x_recon)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
target = noise if not self.predict_x_start else image_embed
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -704,7 +827,12 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
batch_size = text.shape[0]
|
||||
image_embed_dim = self.image_embed_dim
|
||||
|
||||
text_cond = self.get_text_cond(text)
|
||||
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
text_cond = dict(text_embed = text_embed)
|
||||
|
||||
if self.condition_on_text_encodings:
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||
text_embeds = text_cond['text_embed']
|
||||
@@ -736,18 +864,18 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
||||
|
||||
if exists(image):
|
||||
image_embed = self.get_image_embed(image)
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
|
||||
# calculate text conditionings, based on what is passed in
|
||||
|
||||
if exists(text):
|
||||
text_cond = self.get_text_cond(text)
|
||||
else:
|
||||
text_cond = dict(
|
||||
text_embed = text_embed,
|
||||
text_encodings = text_encodings,
|
||||
mask = text_mask
|
||||
)
|
||||
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
text_cond = dict(text_embed = text_embed)
|
||||
|
||||
if self.condition_on_text_encodings:
|
||||
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||
|
||||
# timestep conditioning from ddpm
|
||||
|
||||
@@ -756,8 +884,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
# calculate forward loss
|
||||
|
||||
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
|
||||
return loss
|
||||
return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
|
||||
|
||||
# decoder
|
||||
|
||||
@@ -1027,13 +1154,14 @@ class Unet(nn.Module):
|
||||
self,
|
||||
*,
|
||||
lowres_cond,
|
||||
channels
|
||||
channels,
|
||||
cond_on_image_embeds
|
||||
):
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels:
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
|
||||
return self
|
||||
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
|
||||
return self.__class__(**updated_kwargs)
|
||||
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
|
||||
return self.__class__(**{**self._locals, **updated_kwargs})
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
@@ -1170,7 +1298,7 @@ class LowresConditioner(nn.Module):
|
||||
target_image_size = cast_tuple(target_image_size, 2)
|
||||
|
||||
if self.training and self.downsample_first and exists(downsample_image_size):
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
|
||||
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
@@ -1208,8 +1336,12 @@ class Decoder(BaseGaussianDiffusion):
|
||||
loss_type = loss_type
|
||||
)
|
||||
|
||||
assert isinstance(clip, CLIP)
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip)
|
||||
|
||||
freeze_model_and_make_eval_(clip)
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
|
||||
self.clip = clip
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
@@ -1236,6 +1368,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
cond_on_image_embeds = is_first,
|
||||
channels = unet_channels
|
||||
)
|
||||
|
||||
@@ -1290,18 +1423,11 @@ class Decoder(BaseGaussianDiffusion):
|
||||
yield
|
||||
unet.cpu()
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_encodings(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
return text_encodings[:, 1:]
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
image = resize_image_to(image, self.clip_image_size)
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return l2norm(image_embed)
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
return image_embed
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
@@ -1347,14 +1473,14 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
return img
|
||||
|
||||
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
|
||||
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
||||
|
||||
x_recon = unet(
|
||||
pred = unet(
|
||||
x_noisy,
|
||||
t,
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
@@ -1363,15 +1489,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
|
||||
if self.loss_type == 'l1':
|
||||
loss = F.l1_loss(target, x_recon)
|
||||
elif self.loss_type == 'l2':
|
||||
loss = F.mse_loss(target, x_recon)
|
||||
elif self.loss_type == "huber":
|
||||
loss = F.smooth_l1_loss(target, x_recon)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -1379,9 +1497,12 @@ class Decoder(BaseGaussianDiffusion):
|
||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||
batch_size = image_embed.shape[0]
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
text_encodings = None
|
||||
if exists(text):
|
||||
_, text_encodings, _ = self.clip.embed_text(text)
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
|
||||
img = None
|
||||
|
||||
@@ -1442,11 +1563,14 @@ class Decoder(BaseGaussianDiffusion):
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
|
||||
if not exists(image_embed):
|
||||
image_embed = self.get_image_embed(image)
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||
text_encodings = None
|
||||
if exists(text) and not exists(text_encodings):
|
||||
_, text_encodings, _ = self.clip.embed_text(text)
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
image = resize_image_to(image, target_image_size)
|
||||
@@ -1479,12 +1603,15 @@ class DALLE2(nn.Module):
|
||||
self.prior_num_samples = prior_num_samples
|
||||
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
|
||||
|
||||
self.to_pil = T.ToPILImage()
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def forward(
|
||||
self,
|
||||
text,
|
||||
cond_scale = 1.
|
||||
cond_scale = 1.,
|
||||
return_pil_images = False
|
||||
):
|
||||
device = next(self.parameters()).device
|
||||
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
|
||||
@@ -1498,7 +1625,11 @@ class DALLE2(nn.Module):
|
||||
text_cond = text if self.decoder_need_text_cond else None
|
||||
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
|
||||
|
||||
if return_pil_images:
|
||||
images = list(map(self.to_pil, images.unbind(dim = 0)))
|
||||
|
||||
if one_text:
|
||||
return images[0]
|
||||
|
||||
return images
|
||||
|
||||
|
||||
84
dalle2_pytorch/openai_clip.py
Normal file
84
dalle2_pytorch/openai_clip.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import BaseClipAdapter
|
||||
import torchvision.transforms as T
|
||||
|
||||
def find_layer(model, layer):
|
||||
modules = dict([*model.named_modules()])
|
||||
return modules.get(layer, None)
|
||||
|
||||
def hook(_, input, output):
|
||||
print(output.shape)
|
||||
|
||||
import clip
|
||||
# image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
|
||||
text = clip.tokenize(["a diagram", "a dog", "a cat"]).cuda()
|
||||
image = torch.randn(1, 3, 224, 224).cuda()
|
||||
|
||||
|
||||
class OpenAIClipAdapter(BaseClipAdapter):
|
||||
def __init__(self, name = 'ViT-B/32'):
|
||||
try:
|
||||
import clip
|
||||
except ImportError:
|
||||
print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
|
||||
|
||||
openai_clip, _ = clip.load(name)
|
||||
super().__init__(openai_clip)
|
||||
|
||||
text_attention_final = self.find_layer(self.clip, 'ln_final')
|
||||
self.handle = text_attention_final.register_forward_hook(self._hook)
|
||||
self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
self.cleared = False
|
||||
|
||||
def find_layer(self, layer):
|
||||
modules = dict([*self.clip.named_modules()])
|
||||
return modules.get(layer, None)
|
||||
|
||||
def clear(self):
|
||||
if self.cleared:
|
||||
return
|
||||
|
||||
self.handle()
|
||||
|
||||
def _hook(self, _, inputs, outputs):
|
||||
self.text_encodings = outputs
|
||||
|
||||
@property
|
||||
def dim_latent(self):
|
||||
return 512
|
||||
|
||||
@property
|
||||
def image_size(self):
|
||||
return self.clip.visual.input_resolution
|
||||
|
||||
@property
|
||||
def image_channels(self):
|
||||
return 3
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_text(self, text):
|
||||
assert not self.cleared
|
||||
|
||||
text_embed = self.clip.encode_text(text)
|
||||
text_encodings = self.text_encodings
|
||||
del self.text_encodings
|
||||
return text_embed, text_encodings
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
assert not self.cleared
|
||||
|
||||
image = self.clip_normalize(image)
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return image_embed, None
|
||||
|
||||
clip_adapter = OpenAIClipAdapter().cuda()
|
||||
|
||||
# print(model)
|
||||
with torch.no_grad():
|
||||
image_features, _ = clip_adapter.embed_image(image)
|
||||
text_features, text_encodings = clip_adapter.embed_text(text)
|
||||
print(text_features.shape, image_features.shape)
|
||||
print(text_encodings.shape)
|
||||
29
dalle2_pytorch/optimizer.py
Normal file
29
dalle2_pytorch/optimizer.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from torch.optim import AdamW, Adam
|
||||
|
||||
def separate_weight_decayable_params(params):
|
||||
no_wd_params = set([param for param in params if param.ndim < 2])
|
||||
wd_params = set(params) - no_wd_params
|
||||
return wd_params, no_wd_params
|
||||
|
||||
def get_optimizer(
|
||||
params,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
betas = (0.9, 0.999),
|
||||
filter_by_requires_grad = False
|
||||
):
|
||||
if filter_by_requires_grad:
|
||||
params = list(filter(lambda t: t.requires_grad, params))
|
||||
|
||||
if wd == 0:
|
||||
return Adam(params, lr = lr, betas = betas)
|
||||
|
||||
params = set(params)
|
||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||
|
||||
param_groups = [
|
||||
{'params': list(wd_params)},
|
||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||
]
|
||||
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
|
||||
@@ -545,6 +545,7 @@ class VQGanVAE(nn.Module):
|
||||
l2_recon_loss = False,
|
||||
use_hinge_loss = True,
|
||||
vgg = None,
|
||||
vq_codebook_dim = 256,
|
||||
vq_codebook_size = 512,
|
||||
vq_decay = 0.8,
|
||||
vq_commitment_weight = 1.,
|
||||
@@ -579,6 +580,7 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
self.vq = VQ(
|
||||
dim = self.enc_dec.encoded_dim,
|
||||
codebook_dim = vq_codebook_dim,
|
||||
codebook_size = vq_codebook_size,
|
||||
decay = vq_decay,
|
||||
commitment_weight = vq_commitment_weight,
|
||||
|
||||
4
setup.py
4
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.0.56',
|
||||
version = '0.0.70',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -31,7 +31,7 @@ setup(
|
||||
'torchvision',
|
||||
'tqdm',
|
||||
'vector-quantize-pytorch',
|
||||
'x-clip>=0.4.4',
|
||||
'x-clip>=0.5.1',
|
||||
'youtokentome'
|
||||
],
|
||||
classifiers=[
|
||||
|
||||
Reference in New Issue
Block a user