Compare commits

...

17 Commits

Author SHA1 Message Date
Phil Wang
5fff22834e be able to finely customize learning parameters for each unet, take care of gradient clipping 2022-04-30 11:56:05 -07:00
Phil Wang
a9421f49ec simplify Decoder training for the public 2022-04-30 11:45:18 -07:00
Phil Wang
77fa34eae9 fix all clipping / clamping issues 2022-04-30 10:08:24 -07:00
Phil Wang
1c1e508369 fix all issues with text encodings conditioning in the decoder, using null padding tokens technique from dalle v1 2022-04-30 09:13:34 -07:00
Phil Wang
f19c99ecb0 fix decoder needing separate conditional dropping probabilities for image embeddings and text encodings, thanks to @xiankgx ! 2022-04-30 08:48:05 -07:00
Phil Wang
721a444686 Merge pull request #37 from ProGamerGov/patch-1
Fix spelling and grammatical errors
2022-04-30 08:19:07 -07:00
ProGamerGov
63450b466d Fix spelling and grammatical errors 2022-04-30 09:18:13 -06:00
Phil Wang
20e7eb5a9b cleanup 2022-04-30 07:22:57 -07:00
Phil Wang
e2f9615afa use @clip-anytorch , thanks to @rom1504 2022-04-30 06:40:54 -07:00
Phil Wang
0d1c07c803 fix a bug with classifier free guidance, thanks to @xiankgx again! 2022-04-30 06:34:57 -07:00
Phil Wang
a389f81138 todo 2022-04-29 15:40:51 -07:00
Phil Wang
0283556608 fix example in readme, since api changed 2022-04-29 13:40:55 -07:00
Phil Wang
5063d192b6 now completely OpenAI CLIP compatible for training
just take care of the logic for AdamW and transformers

used namedtuples for clip adapter embedding outputs
2022-04-29 13:05:01 -07:00
Phil Wang
f4a54e475e add some training fns 2022-04-29 09:44:55 -07:00
Phil Wang
fb662a62f3 fix another bug thanks to @xiankgx 2022-04-29 07:38:32 -07:00
Phil Wang
587c8c9b44 optimize for clarity 2022-04-28 21:59:13 -07:00
Phil Wang
aa900213e7 force first unet in the cascade to be conditioned on image embeds 2022-04-28 20:53:15 -07:00
6 changed files with 515 additions and 87 deletions

183
README.md
View File

@@ -47,7 +47,7 @@ clip = CLIP(
use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP) use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)
decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL) decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB) extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
use_visual_ssl = True, # whether to do self supervised learning on iages use_visual_ssl = True, # whether to do self supervised learning on images
visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
use_mlm = False, # use masked language learning (MLM) on text (DeCLIP) use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)
text_ssl_loss_weight = 0.05, # weight for text MLM loss text_ssl_loss_weight = 0.05, # weight for text MLM loss
@@ -110,7 +110,8 @@ decoder = Decoder(
unet = unet, unet = unet,
clip = clip, clip = clip,
timesteps = 100, timesteps = 100,
cond_drop_prob = 0.2 image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda() ).cuda()
# mock images (get a lot of this) # mock images (get a lot of this)
@@ -229,7 +230,8 @@ decoder = Decoder(
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here) unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in) image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000, timesteps = 1000,
cond_drop_prob = 0.2 image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda() ).cuda()
# mock images (get a lot of this) # mock images (get a lot of this)
@@ -348,7 +350,8 @@ decoder = Decoder(
image_sizes = (128, 256), image_sizes = (128, 256),
clip = clip, clip = clip,
timesteps = 100, timesteps = 100,
cond_drop_prob = 0.2, image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda() ).cuda()
@@ -430,8 +433,8 @@ images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings # precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone # here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = diffusion_prior.get_image_embed(images) clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed') clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed
# feed text and images into diffusion prior network # feed text and images into diffusion prior network
@@ -495,6 +498,95 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings # now the diffusion prior can generate image embeddings from the text embeddings
``` ```
## OpenAI CLIP
Although it is possible they are using an unreleased, more powerful CLIP, you can use one of the released ones if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.
To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so
```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
# openai pretrained clip - defaults to ViT-B/32
clip = OpenAIClipAdapter()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# prior networks (with transformer)
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
loss = diffusion_prior(text, images)
loss.backward()
# do above for many steps ...
# decoder (with unet)
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
dalle2 = DALLE2(
prior = diffusion_prior,
decoder = decoder
)
images = dalle2(
['a butterfly trying to escape a tornado'],
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image (in this example, of size 256x256)
```
Now you'll just have to worry about training the Prior and the Decoder!
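The `# save your image` comment above is left to the reader; one way to do it (my own suggestion, not part of the repository) is torchvision's `save_image`, assuming the returned tensor is an image (or batch of images) scaled to `[0, 1]`
```python
from torchvision.utils import save_image

# `images` is the tensor returned by dalle2(...) above
save_image(images, 'butterfly.png')
```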
## Experimental ## Experimental
### DALL-E2 with Latent Diffusion ### DALL-E2 with Latent Diffusion
@@ -528,7 +620,7 @@ clip = CLIP(
# 3 unets for the decoder (a la cascading DDPM) # 3 unets for the decoder (a la cascading DDPM)
# first two unets are doing latent diffusion # first two unets are doing latent diffusion
# vqgan-vae must be trained before hand # vqgan-vae must be trained beforehand
vae1 = VQGanVAE( vae1 = VQGanVAE(
dim = 32, dim = 32,
@@ -581,7 +673,8 @@ decoder = Decoder(
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here) unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
timesteps = 100, timesteps = 100,
cond_drop_prob = 0.2 image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda() ).cuda()
# mock images (get a lot of this) # mock images (get a lot of this)
@@ -615,7 +708,77 @@ images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
## Training wrapper (wip) ## Training wrapper (wip)
Offer training wrappers ### Decoder Training
Training the `Decoder` may be confusing, as one needs to keep track of an optimizer for each of the `Unet`(s) separately. Each `Unet` will also need its own corresponding exponential moving average. The `DecoderTrainer` hopes to make this simple, as shown below
```python
import torch
from dalle2_pytorch import DALLE2, Unet, Decoder, CLIP, DecoderTrainer
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# decoder (with unet)
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_text_encodings = True
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 1,
condition_on_text_encodings = True
).cuda()
decoder_trainer = DecoderTrainer(
decoder,
lr = 3e-4,
wd = 1e-2,
ema_beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
)
for unet_number in (1, 2):
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
loss.backward()
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
```
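Per the latest commit, the training hyperparameters can apparently also be customized per unet: `lr` and `wd` accept a tuple with one entry per unet, and `max_grad_norm` turns on gradient clipping inside `update`. A hedged sketch, reusing the `decoder` constructed above
```python
decoder_trainer = DecoderTrainer(
    decoder,
    lr = (3e-4, 1e-4),     # one learning rate per unet
    wd = (1e-2, 1e-2),     # one weight decay per unet
    max_grad_norm = 0.5,   # gradients are clipped inside decoder_trainer.update(...)
    ema_beta = 0.99
)
```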
## CLI (wip) ## CLI (wip)
@@ -648,9 +811,11 @@ Once built, images will be saved to the same directory the command is invoked
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms - [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion - [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in - [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
- [ ] take care of mixed precision as well as gradient accumulation within decoder trainer
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [ ] train on a toy task, offer in colab - [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference

dalle2_pytorch/__init__.py
View File

@@ -1,4 +1,6 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP from x_clip import CLIP

dalle2_pytorch/dalle2_pytorch.py
View File

@@ -3,6 +3,7 @@ from tqdm import tqdm
from inspect import isfunction from inspect import isfunction
from functools import partial from functools import partial
from contextlib import contextmanager from contextlib import contextmanager
from collections import namedtuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -90,8 +91,21 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
return F.interpolate(t, size = shape, mode = mode, align_corners = False) return F.interpolate(t, size = shape, mode = mode, align_corners = False)
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
# but CLIP may expect otherwise (its preprocessing assumes images in 0 to 1)
def normalize_img(img):
return img * 2 - 1
def unnormalize_img(normed_img):
return (normed_img + 1) * 0.5
# clip related adapters # clip related adapters
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
class BaseClipAdapter(nn.Module): class BaseClipAdapter(nn.Module):
def __init__(self, clip): def __init__(self, clip):
super().__init__() super().__init__()
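The adapter embedding methods now return the namedtuples above instead of bare tuples, which is what the updated README examples rely on (`.image_embed`, `.text_embed`). An illustrative sketch, assuming an already-constructed adapter `clip_adapter` plus tokenized `text` and `images`
```python
# fields can be accessed by name or unpacked positionally
text_embed, text_encodings, text_mask = clip_adapter.embed_text(text)
image_embed = clip_adapter.embed_image(images).image_embed
```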
@@ -109,6 +123,10 @@ class BaseClipAdapter(nn.Module):
def image_channels(self): def image_channels(self):
raise NotImplementedError raise NotImplementedError
@property
def max_text_len(self):
raise NotImplementedError
def embed_text(self, text): def embed_text(self, text):
raise NotImplementedError raise NotImplementedError
@@ -128,12 +146,18 @@ class XClipAdapter(BaseClipAdapter):
def image_channels(self): def image_channels(self):
return self.clip.image_channels return self.clip.image_channels
@property
def max_text_len(self):
return self.clip.text_seq_len
@torch.no_grad() @torch.no_grad()
def embed_text(self, text): def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
encoder_output = self.clip.text_transformer(text) encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:] text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls) text_embed = self.clip.to_text_latent(text_cls)
return l2norm(text_embed), text_encodings return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
@torch.no_grad() @torch.no_grad()
def embed_image(self, image): def embed_image(self, image):
@@ -141,7 +165,69 @@ class XClipAdapter(BaseClipAdapter):
encoder_output = self.clip.visual_transformer(image) encoder_output = self.clip.visual_transformer(image)
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:] image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
image_embed = self.clip.to_visual_latent(image_cls) image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed), image_encodings return EmbeddedImage(l2norm(image_embed), image_encodings)
class OpenAIClipAdapter(BaseClipAdapter):
def __init__(
self,
name = 'ViT-B/32'
):
import clip
openai_clip, preprocess = clip.load(name)
super().__init__(openai_clip)
text_attention_final = self.find_layer('ln_final')
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
def find_layer(self, layer):
modules = dict([*self.clip.named_modules()])
return modules.get(layer, None)
def clear(self):
if self.cleared:
return
self.handle()
def _hook(self, _, inputs, outputs):
self.text_encodings = outputs
@property
def dim_latent(self):
return 512
@property
def image_size(self):
return self.clip.visual.input_resolution
@property
def image_channels(self):
return 3
@property
def max_text_len(self):
return self.clip.context_length
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
assert not self.cleared
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
del self.text_encodings
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
@torch.no_grad()
def embed_image(self, image):
assert not self.cleared
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image_embed = self.clip.encode_image(image)
return EmbeddedImage(image_embed.float(), None)
# classifier free guidance functions # classifier free guidance functions
@@ -223,7 +309,18 @@ class BaseGaussianDiffusion(nn.Module):
timesteps, = betas.shape timesteps, = betas.shape
self.num_timesteps = int(timesteps) self.num_timesteps = int(timesteps)
if loss_type == 'l1':
loss_fn = F.l1_loss
elif loss_type == 'l2':
loss_fn = F.mse_loss
elif loss_type == 'huber':
loss_fn = F.smooth_l1_loss
else:
raise NotImplementedError()
self.loss_type = loss_type self.loss_type = loss_type
self.loss_fn = loss_fn
self.register_buffer('betas', betas) self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod) self.register_buffer('alphas_cumprod', alphas_cumprod)
@@ -587,14 +684,14 @@ class DiffusionPriorNetwork(nn.Module):
# classifier free guidance # classifier free guidance
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device) keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1') keep_mask = rearrange(keep_mask, 'b -> b 1')
mask &= cond_prob_mask mask &= keep_mask
# whether text embedding is masked or not depends on the classifier free guidance conditional masking # whether text embedding is masked or not depends on the classifier free guidance conditional masking
mask = torch.cat((mask, cond_prob_mask), dim = 1) mask = torch.cat((mask, keep_mask), dim = 1)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different) # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right # but let's just do it right
@@ -639,6 +736,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
predict_x_start = True, predict_x_start = True,
beta_schedule = "cosine", beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False
): ):
super().__init__( super().__init__(
beta_schedule = beta_schedule, beta_schedule = beta_schedule,
@@ -667,6 +765,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.predict_x_start = predict_x_start self.predict_x_start = predict_x_start
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
# whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool): def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond) pred = self.net(x, t, **text_cond)
@@ -680,6 +781,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
if clip_denoised and not self.predict_x_start: if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1., 1.) x_recon.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance
@@ -703,29 +807,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond) img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img return img
def p_losses(self, image_embed, t, text_cond, noise = None): def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed)) noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise) image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
x_recon = self.net( pred = self.net(
image_embed_noisy, image_embed_noisy,
t, times,
cond_drop_prob = self.cond_drop_prob, cond_drop_prob = self.cond_drop_prob,
**text_cond **text_cond
) )
to_predict = noise if not self.predict_x_start else image_embed target = noise if not self.predict_x_start else image_embed
if self.loss_type == 'l1':
loss = F.l1_loss(to_predict, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(to_predict, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(to_predict, x_recon)
else:
raise NotImplementedError()
loss = self.loss_fn(pred, target)
return loss return loss
@torch.no_grad() @torch.no_grad()
@@ -738,12 +834,12 @@ class DiffusionPrior(BaseGaussianDiffusion):
batch_size = text.shape[0] batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim image_embed_dim = self.image_embed_dim
text_embed, text_encodings = self.clip.embed_text(text) text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_cond = dict(text_embed = text_embed) text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings: if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0} text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond) image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed'] text_embeds = text_cond['text_embed']
@@ -780,8 +876,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# calculate text conditionings, based on what is passed in # calculate text conditionings, based on what is passed in
if exists(text): if exists(text):
text_embed, text_encodings = self.clip.embed_text(text) text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_mask = text != 0
text_cond = dict(text_embed = text_embed) text_cond = dict(text_embed = text_embed)
@@ -1002,7 +1097,12 @@ class Unet(nn.Module):
Rearrange('b (n d) -> b n d', n = num_image_tokens) Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity() ) if image_embed_dim != cond_dim else nn.Identity()
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim) # text encoding conditioning (optional)
self.text_to_cond = None
if cond_on_text_encodings:
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
# finer control over whether to condition on image embeddings and text encodings # finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1013,6 +1113,8 @@ class Unet(nn.Module):
# for classifier free guidance # for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim)) self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
# attention related params # attention related params
@@ -1066,13 +1168,14 @@ class Unet(nn.Module):
self, self,
*, *,
lowres_cond, lowres_cond,
channels channels,
cond_on_image_embeds
): ):
if lowres_cond == self.lowres_cond and channels == self.channels: if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
return self return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels} updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
return self.__class__(**updated_kwargs) return self.__class__(**{**self._locals, **updated_kwargs})
def forward_with_cond_scale( def forward_with_cond_scale(
self, self,
@@ -1085,7 +1188,7 @@ class Unet(nn.Module):
if cond_scale == 1: if cond_scale == 1:
return logits return logits
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale return null_logits + (logits - null_logits) * cond_scale
def forward( def forward(
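For intuition on the guidance formula above: `cond_scale = 1` simply returns the conditional prediction, while larger values push the output further away from the fully-dropped ("null") prediction. A toy numeric sketch with plain tensors (not the actual Unet outputs)
```python
import torch

logits = torch.tensor([1.0])        # conditional prediction (toy value)
null_logits = torch.tensor([0.2])   # prediction with image and text conditioning both dropped
cond_scale = 2.0

guided = null_logits + (logits - null_logits) * cond_scale   # tensor([1.8]) -> condition strengthened
```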
@@ -1096,7 +1199,9 @@ class Unet(nn.Module):
image_embed, image_embed,
lowres_cond_img = None, lowres_cond_img = None,
text_encodings = None, text_encodings = None,
cond_drop_prob = 0., text_mask = None,
image_cond_drop_prob = 0.,
text_cond_drop_prob = 0.,
blur_sigma = None, blur_sigma = None,
blur_kernel_size = None blur_kernel_size = None
): ):
@@ -1115,8 +1220,10 @@ class Unet(nn.Module):
# conditional dropout # conditional dropout
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device) image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1') text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
# mask out image embedding depending on condition dropout # mask out image embedding depending on condition dropout
# for classifier free guidance # for classifier free guidance
@@ -1127,7 +1234,7 @@ class Unet(nn.Module):
image_tokens = self.image_to_cond(image_embed) image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where( image_tokens = torch.where(
cond_prob_mask, image_keep_mask,
image_tokens, image_tokens,
self.null_image_embed self.null_image_embed
) )
@@ -1138,10 +1245,25 @@ class Unet(nn.Module):
if exists(text_encodings) and self.cond_on_text_encodings: if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings) text_tokens = self.text_to_cond(text_encodings)
text_tokens = text_tokens[:, :self.max_text_len]
text_tokens_len = text_tokens.shape[1]
remainder = self.max_text_len - text_tokens_len
if remainder > 0:
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
if exists(text_mask):
if remainder > 0:
text_mask = F.pad(text_mask, (0, remainder), value = False)
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask
text_tokens = torch.where( text_tokens = torch.where(
cond_prob_mask, text_keep_mask,
text_tokens, text_tokens,
self.null_text_embed[:, :text_tokens.shape[1]] self.null_text_embed
) )
# main conditioning tokens (c) # main conditioning tokens (c)
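The padding above is the null-padding-token technique referenced in the commit messages: text tokens shorter than `max_text_len` are right-padded, the mask is padded with `False`, and those positions therefore fall back to the learned `null_text_embed`. A toy illustration of the shape bookkeeping (hypothetical sizes)
```python
import torch
import torch.nn.functional as F

max_text_len = 8
text_tokens = torch.randn(2, 5, 128)               # (batch, current length, cond_dim)
text_mask = torch.ones(2, 5, dtype = torch.bool)

remainder = max_text_len - text_tokens.shape[1]
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))        # -> (2, 8, 128)
text_mask = F.pad(text_mask, (0, remainder), value = False)   # -> (2, 8), padded positions masked out
```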
@@ -1209,7 +1331,7 @@ class LowresConditioner(nn.Module):
target_image_size = cast_tuple(target_image_size, 2) target_image_size = cast_tuple(target_image_size, 2)
if self.training and self.downsample_first and exists(downsample_image_size): if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode) cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
if self.training: if self.training:
# when training, blur the low resolution conditional image # when training, blur the low resolution conditional image
@@ -1229,7 +1351,8 @@ class Decoder(BaseGaussianDiffusion):
clip, clip,
vae = tuple(), vae = tuple(),
timesteps = 1000, timesteps = 1000,
cond_drop_prob = 0.2, image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l1', loss_type = 'l1',
beta_schedule = 'cosine', beta_schedule = 'cosine',
predict_x_start = False, predict_x_start = False,
@@ -1240,6 +1363,8 @@ class Decoder(BaseGaussianDiffusion):
blur_sigma = 0.1, # cascading ddpm - blur sigma blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True,
clip_x_start = True
): ):
super().__init__( super().__init__(
beta_schedule = beta_schedule, beta_schedule = beta_schedule,
@@ -1279,6 +1404,7 @@ class Decoder(BaseGaussianDiffusion):
one_unet = one_unet.cast_model_parameters( one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first, lowres_cond = not is_first,
cond_on_image_embeds = is_first,
channels = unet_channels channels = unet_channels
) )
@@ -1312,7 +1438,13 @@ class Decoder(BaseGaussianDiffusion):
# classifier free guidance # classifier free guidance
self.cond_drop_prob = cond_drop_prob self.image_cond_drop_prob = image_cond_drop_prob
self.text_cond_drop_prob = text_cond_drop_prob
# whether to clip when sampling
self.clip_denoised = clip_denoised
self.clip_x_start = clip_x_start
def get_unet(self, unet_number): def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets) assert 0 < unet_number <= len(self.unets)
@@ -1336,37 +1468,34 @@ class Decoder(BaseGaussianDiffusion):
@torch.no_grad() @torch.no_grad()
def get_image_embed(self, image): def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size) image_embed, _ = self.clip.embed_image(image)
image_encoding = self.clip.visual_transformer(image) return image_embed
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.): def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
if predict_x_start: if predict_x_start:
x_recon = pred x_recon = pred
else: else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred) x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised and not predict_x_start: if clip_denoised:
x_recon.clamp_(-1., 1.) x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad() @torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False): def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start) model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
noise = noise_like(x.shape, device, repeat_noise) noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad() @torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1): def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
device = self.betas.device device = self.betas.device
b = shape[0] b = shape[0]
@@ -1379,38 +1508,34 @@ class Decoder(BaseGaussianDiffusion):
torch.full((b,), i, device = device, dtype = torch.long), torch.full((b,), i, device = device, dtype = torch.long),
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale, cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start predict_x_start = predict_x_start,
clip_denoised = clip_denoised
) )
return img return img
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None): def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise) x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
x_recon = unet( pred = unet(
x_noisy, x_noisy,
t, times,
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
text_mask = text_mask,
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
cond_drop_prob = self.cond_drop_prob image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
) )
target = noise if not predict_x_start else x_start target = noise if not predict_x_start else x_start
if self.loss_type == 'l1': loss = self.loss_fn(pred, target)
loss = F.l1_loss(target, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(target, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(target, x_recon)
else:
raise NotImplementedError()
return loss return loss
@torch.no_grad() @torch.no_grad()
@@ -1418,9 +1543,9 @@ class Decoder(BaseGaussianDiffusion):
def sample(self, image_embed, text = None, cond_scale = 1.): def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0] batch_size = image_embed.shape[0]
text_encodings = None text_encodings = text_mask = None
if exists(text): if exists(text):
_, text_encodings = self.clip.embed_text(text) _, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
@@ -1438,6 +1563,7 @@ class Decoder(BaseGaussianDiffusion):
if unet.lowres_cond: if unet.lowres_cond:
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size) lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
is_latent_diffusion = isinstance(vae, VQGanVAE)
image_size = vae.get_encoded_fmap_size(image_size) image_size = vae.get_encoded_fmap_size(image_size)
shape = (batch_size, vae.encoded_dim, image_size, image_size) shape = (batch_size, vae.encoded_dim, image_size, image_size)
@@ -1449,8 +1575,10 @@ class Decoder(BaseGaussianDiffusion):
shape, shape,
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale, cond_scale = cond_scale,
predict_x_start = predict_x_start, predict_x_start = predict_x_start,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img lowres_cond_img = lowres_cond_img
) )
@@ -1486,9 +1614,9 @@ class Decoder(BaseGaussianDiffusion):
if not exists(image_embed): if not exists(image_embed):
image_embed, _ = self.clip.embed_image(image) image_embed, _ = self.clip.embed_image(image)
text_encodings = None text_encodings = text_mask = None
if exists(text) and not exists(text_encodings): if exists(text) and not exists(text_encodings):
_, text_encodings = self.clip.embed_text(text) _, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
@@ -1503,7 +1631,7 @@ class Decoder(BaseGaussianDiffusion):
if exists(lowres_cond_img): if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img) lowres_cond_img = vae.encode(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start) return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
# main class # main class
@@ -1553,4 +1681,3 @@ class DALLE2(nn.Module):
return images[0] return images[0]
return images return images

dalle2_pytorch/optimizer.py
View File

@@ -0,0 +1,29 @@
from torch.optim import AdamW, Adam
def separate_weight_decayable_params(params):
no_wd_params = set([param for param in params if param.ndim < 2])
wd_params = set(params) - no_wd_params
return wd_params, no_wd_params
def get_optimizer(
params,
lr = 3e-4,
wd = 1e-2,
betas = (0.9, 0.999),
filter_by_requires_grad = False
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if wd == 0:
return Adam(params, lr = lr, betas = betas)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
param_groups = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
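A hedged usage sketch of the new helper (the import path matches the one used in train.py below). Parameters with fewer than two dimensions, such as biases and LayerNorm weights, are routed into a zero-weight-decay group
```python
import torch.nn as nn
from dalle2_pytorch.optimizer import get_optimizer

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32))
opt = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-2)
# 2D+ weights land in the weight-decayed AdamW group, 1D parameters in the no-decay group
```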

dalle2_pytorch/train.py
View File

@@ -1,7 +1,43 @@
import copy import copy
from functools import partial
import torch import torch
from torch import nn from torch import nn
from dalle2_pytorch.dalle2_pytorch import Decoder
from dalle2_pytorch.optimizer import get_optimizer
# helper functions
def exists(val):
return val is not None
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
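These helpers let `DecoderTrainer` peel off `ema_`-prefixed keyword arguments and forward them, with the prefix stripped, to the `EMA` wrapper. Roughly (illustration only)
```python
kwargs = dict(ema_beta = 0.99, ema_update_every = 10, betas = (0.9, 0.999))
ema_kwargs, rest = groupby_prefix_and_trim('ema_', kwargs)
# ema_kwargs == {'beta': 0.99, 'update_every': 10}
# rest       == {'betas': (0.9, 0.999)}
```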
# exponential moving average wrapper # exponential moving average wrapper
class EMA(nn.Module): class EMA(nn.Module):
@@ -9,16 +45,16 @@ class EMA(nn.Module):
self, self,
model, model,
beta = 0.99, beta = 0.99,
ema_update_after_step = 1000, update_after_step = 1000,
ema_update_every = 10, update_every = 10,
): ):
super().__init__() super().__init__()
self.beta = beta self.beta = beta
self.online_model = model self.online_model = model
self.ema_model = copy.deepcopy(model) self.ema_model = copy.deepcopy(model)
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0 self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
self.ema_update_every = ema_update_every self.update_every = update_every
self.register_buffer('initted', torch.Tensor([False])) self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.])) self.register_buffer('step', torch.tensor([0.]))
@@ -26,7 +62,7 @@ class EMA(nn.Module):
def update(self): def update(self):
self.step += 1 self.step += 1
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0: if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
return return
if not self.initted: if not self.initted:
@@ -51,3 +87,71 @@ class EMA(nn.Module):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs) return self.ema_model(*args, **kwargs)
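With the renamed arguments (`update_after_step`, `update_every`), the wrapper is used roughly as below — a sketch, assuming `EMA` is importable from `dalle2_pytorch.train` alongside `DecoderTrainer`
```python
import torch
from torch import nn
from dalle2_pytorch.train import EMA

model = nn.Linear(8, 8)
ema = EMA(model, beta = 0.99, update_after_step = 1000, update_every = 10)

for step in range(10):
    # ... forward / backward / optimizer.step() on the online model goes here ...
    ema.update()  # no-op until update_after_step, then updates every update_every steps

out = ema(torch.randn(1, 8))  # __call__ forwards to the EMA copy
```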
# trainers
class DecoderTrainer(nn.Module):
def __init__(
self,
decoder,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
max_grad_norm = None,
**kwargs
):
super().__init__()
assert isinstance(decoder, Decoder)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.decoder = decoder
self.num_unets = len(self.decoder.unets)
self.use_ema = use_ema
if use_ema:
has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
self.ema_unets = nn.ModuleList([])
# be able to finely customize learning rate, weight decay
# per unet
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
**kwargs
)
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
def update(self, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
unet = self.decoder.unets[index]
if exists(self.max_grad_norm):
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
optimizer = getattr(self, f'optim{index}')
optimizer.step()
optimizer.zero_grad()
if self.use_ema:
ema_unet = self.ema_unets[index]
ema_unet.update()
def forward(self, x, *, unet_number, **kwargs):
return self.decoder(x, unet_number = unet_number, **kwargs)

setup.py
View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.63', version = '0.0.78',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',
@@ -23,6 +23,7 @@ setup(
], ],
install_requires=[ install_requires=[
'click', 'click',
'clip-anytorch',
'einops>=0.4', 'einops>=0.4',
'einops-exts>=0.0.3', 'einops-exts>=0.0.3',
'kornia>=0.5.4', 'kornia>=0.5.4',
@@ -31,7 +32,7 @@ setup(
'torchvision', 'torchvision',
'tqdm', 'tqdm',
'vector-quantize-pytorch', 'vector-quantize-pytorch',
'x-clip>=0.4.4', 'x-clip>=0.5.1',
'youtokentome' 'youtokentome'
], ],
classifiers=[ classifiers=[