Compare commits


1 commit

Author     SHA1        Message                                              Date
Phil Wang  39d3659ad9  now completely OpenAI CLIP compatible for training   2022-04-29 11:26:24 -07:00
4 changed files with 7 additions and 41 deletions

README.md

@@ -430,8 +430,8 @@ images = torch.randn(4, 3, 256, 256).cuda()
 # precompute the text and image embeddings
 # here using the diffusion prior class, but could be done with CLIP alone

-clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
-clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed
+clip_image_embeds = diffusion_prior.get_image_embed(images)
+clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')

 # feed text and images into diffusion prior network
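For context, here is how the two call styles differ at the call site (a minimal sketch, assuming diffusion_prior, images, and text are already constructed as in the surrounding README example):

# "-" side: reach into the wrapped CLIP adapter and read a namedtuple field
clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed

# "+" side: ask the prior directly; text conditioning comes back as a dict
clip_image_embeds = diffusion_prior.get_image_embed(images)
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')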
@@ -741,7 +741,6 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
-- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
 - [ ] train on a toy task, offer in colab
 - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference

dalle2_pytorch/dalle2_pytorch.py

@@ -3,7 +3,6 @@ from tqdm import tqdm
 from inspect import isfunction
 from functools import partial
 from contextlib import contextmanager
-from collections import namedtuple

 import torch
 import torch.nn.functional as F
@@ -103,9 +102,6 @@ def unnormalize_img(normed_img):
 # clip related adapters

-EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
-EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
-
 class BaseClipAdapter(nn.Module):
     def __init__(self, clip):
         super().__init__()
@@ -157,7 +153,7 @@ class XClipAdapter(BaseClipAdapter):
         encoder_output = self.clip.text_transformer(text)
         text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
-        return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
+        return l2norm(text_embed), text_encodings, text_mask

     @torch.no_grad()
     def embed_image(self, image):
@@ -165,7 +161,7 @@ class XClipAdapter(BaseClipAdapter):
         encoder_output = self.clip.visual_transformer(image)
         image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
         image_embed = self.clip.to_visual_latent(image_cls)
-        return EmbeddedImage(l2norm(image_embed), image_encodings)
+        return l2norm(image_embed), image_encodings

 class OpenAIClipAdapter(BaseClipAdapter):
     def __init__(
@@ -223,7 +219,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         text_embed = self.clip.encode_text(text)
         text_encodings = self.text_encodings
         del self.text_encodings
-        return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
+        return text_embed.float(), text_encodings.float(), text_mask

     @torch.no_grad()
     def embed_image(self, image):
@@ -231,7 +227,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         image = resize_image_to(image, self.image_size)
         image = self.clip_normalize(unnormalize_img(image))
         image_embed = self.clip.encode_image(image)
-        return EmbeddedImage(image_embed.float(), None)
+        return image_embed.float(), None

 # classifier free guidance functions
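The net effect on downstream code (a minimal sketch; clip stands for any of the adapters above): named-field access on the "-" side becomes positional tuple unpacking on the "+" side.

# "-" side: returns are namedtuples, so fields can be read by name
text_embed = clip.embed_text(text).text_embed

# "+" side: returns are plain tuples, unpacked positionally
text_embed, text_encodings, text_mask = clip.embed_text(text)
image_embed, image_encodings = clip.embed_image(image)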

dalle2_pytorch/optimizer.py

@@ -1,29 +0,0 @@
-from torch.optim import AdamW, Adam
-
-def separate_weight_decayable_params(params):
-    no_wd_params = set([param for param in params if param.ndim < 2])
-    wd_params = set(params) - no_wd_params
-    return wd_params, no_wd_params
-
-def get_optimizer(
-    params,
-    lr = 3e-4,
-    wd = 1e-2,
-    betas = (0.9, 0.999),
-    filter_by_requires_grad = False
-):
-    if filter_by_requires_grad:
-        params = list(filter(lambda t: t.requires_grad, params))
-
-    if wd == 0:
-        return Adam(params, lr = lr, betas = betas)
-
-    params = set(params)
-    wd_params, no_wd_params = separate_weight_decayable_params(params)
-
-    param_groups = [
-        {'params': list(wd_params)},
-        {'params': list(no_wd_params), 'weight_decay': 0},
-    ]
-
-    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
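For reference, the removed helper was called roughly like this (a minimal usage sketch; model is a hypothetical stand-in for any torch module). Parameters with fewer than two dimensions (biases, norm gains) land in a parameter group with weight decay disabled, while weight matrices keep the configured decay.

import torch.nn as nn

model = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512))

# AdamW with weight decay applied only to the >= 2D parameters
opt = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-2)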

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.71',
+  version = '0.0.67',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
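Since the two sides of this compare expose different APIs, pinning the package version is the simplest way to reproduce either one; assuming the package is published under the repo's usual PyPI name, pip install dalle2-pytorch==0.0.67 should match the "+" side shown here.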