Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-22 20:44:25 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | a389f81138 |  |
|  | 0283556608 |  |
|  | 5063d192b6 |  |
@@ -430,8 +430,8 @@ images = torch.randn(4, 3, 256, 256).cuda()
 # precompute the text and image embeddings
 # here using the diffusion prior class, but could be done with CLIP alone

-clip_image_embeds = diffusion_prior.get_image_embed(images)
-clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
+clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
+clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed

 # feed text and images into diffusion prior network

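The README hunk above switches the precompute step over to the clip adapter API. Below is a minimal, self-contained sketch of how that snippet reads at this commit, mirroring the README's own example setup; the constructor keyword arguments and tensor shapes are taken from that example and should be treated as illustrative rather than authoritative.

```python
# illustrative sketch, assuming dalle2-pytorch at this commit and the README's example setup
import torch
from dalle2_pytorch import CLIP, DiffusionPrior, DiffusionPriorNetwork

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# precompute the text and image embeddings, now via the clip adapter,
# whose return values are namedtuples with fields accessed by name
clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed
```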
@@ -741,6 +741,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
+- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
 - [ ] train on a toy task, offer in colab
 - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
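The newly added todo item about a decoder training wrapper ties into the `get_optimizer` helper introduced further down in this commit (the new `dalle2_pytorch/optimizer.py`): since each unet in the cascade is trained separately, each one would own its own optimizer. A purely hypothetical sketch of that idea follows; `DecoderTrainer`, its method names, and the `decoder.unets` attribute are illustrative assumptions, not the library's API.

```python
# hypothetical sketch only, not part of dalle2-pytorch at this commit;
# it just illustrates "one optimizer per unet in the cascade"
from dalle2_pytorch.optimizer import get_optimizer

class DecoderTrainer:
    def __init__(self, decoder, lr = 3e-4, wd = 1e-2):
        self.decoder = decoder
        # assumes the decoder exposes its cascade of unets as an iterable
        self.optimizers = [
            get_optimizer(unet.parameters(), lr = lr, wd = wd)
            for unet in decoder.unets
        ]

    def update(self, unet_number):
        # step only the optimizer belonging to the unet that was just trained
        opt = self.optimizers[unet_number - 1]
        opt.step()
        opt.zero_grad()
```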
@@ -3,6 +3,7 @@ from tqdm import tqdm
 from inspect import isfunction
 from functools import partial
 from contextlib import contextmanager
+from collections import namedtuple

 import torch
 import torch.nn.functional as F
@@ -102,6 +103,9 @@ def unnormalize_img(normed_img):

 # clip related adapters

+EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
+EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
+
 class BaseClipAdapter(nn.Module):
     def __init__(self, clip):
         super().__init__()
@@ -153,7 +157,7 @@ class XClipAdapter(BaseClipAdapter):
         encoder_output = self.clip.text_transformer(text)
         text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
-        return l2norm(text_embed), text_encodings, text_mask
+        return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)

     @torch.no_grad()
     def embed_image(self, image):
@@ -161,7 +165,7 @@ class XClipAdapter(BaseClipAdapter):
         encoder_output = self.clip.visual_transformer(image)
         image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
         image_embed = self.clip.to_visual_latent(image_cls)
-        return l2norm(image_embed), image_encodings
+        return EmbeddedImage(l2norm(image_embed), image_encodings)

 class OpenAIClipAdapter(BaseClipAdapter):
     def __init__(
@@ -219,7 +223,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         text_embed = self.clip.encode_text(text)
         text_encodings = self.text_encodings
         del self.text_encodings
-        return text_embed.float(), text_encodings.float(), text_mask
+        return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)

     @torch.no_grad()
     def embed_image(self, image):
@@ -227,7 +231,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         image = resize_image_to(image, self.image_size)
         image = self.clip_normalize(unnormalize_img(image))
         image_embed = self.clip.encode_image(image)
-        return image_embed.float(), None
+        return EmbeddedImage(image_embed.float(), None)

 # classifier free guidance functions

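The common thread in the hunks above: `embed_text` / `embed_image` now return the `EmbeddedText` / `EmbeddedImage` namedtuples instead of bare tuples, so callers can pull out fields by name (as the README change does), while existing tuple-unpacking call sites keep working because a namedtuple is still a tuple. A small standalone illustration of that pattern (not library code):

```python
# standalone illustration of the namedtuple return pattern (not library code)
from collections import namedtuple
import torch

EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])

out = EmbeddedImage(torch.randn(4, 512), torch.randn(4, 64, 512))

# new style: access fields by name
image_embed = out.image_embed

# old style still works: namedtuples unpack like plain tuples
image_embed, image_encodings = out

assert image_embed.shape == (4, 512)
```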
dalle2_pytorch/optimizer.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+from torch.optim import AdamW, Adam
+
+def separate_weight_decayable_params(params):
+    no_wd_params = set([param for param in params if param.ndim < 2])
+    wd_params = set(params) - no_wd_params
+    return wd_params, no_wd_params
+
+def get_optimizer(
+    params,
+    lr = 3e-4,
+    wd = 1e-2,
+    betas = (0.9, 0.999),
+    filter_by_requires_grad = False
+):
+    if filter_by_requires_grad:
+        params = list(filter(lambda t: t.requires_grad, params))
+
+    if wd == 0:
+        return Adam(params, lr = lr, betas = betas)
+
+    params = set(params)
+    wd_params, no_wd_params = separate_weight_decayable_params(params)
+
+    param_groups = [
+        {'params': list(wd_params)},
+        {'params': list(no_wd_params), 'weight_decay': 0},
+    ]
+
+    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
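A brief usage sketch for the new helper, assuming the package at this commit is importable; the toy model is just for illustration. Parameters with `ndim < 2` (biases, norm scales) land in a parameter group with `weight_decay = 0`, everything else receives the requested decay.

```python
# usage sketch for the new get_optimizer helper; the toy model is illustrative
import torch.nn as nn
from dalle2_pytorch.optimizer import get_optimizer

model = nn.Sequential(
    nn.Linear(512, 512),  # 2d weight gets weight decay, 1d bias does not
    nn.LayerNorm(512),    # 1d scale / shift parameters get no weight decay
    nn.Linear(512, 512)
)

opt = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-2)

# AdamW with two parameter groups: decayed weights and non-decayed 1d params
assert len(opt.param_groups) == 2
```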