mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
41fabf2922 | ||
|
|
5975e8222b | ||
|
|
c18c080128 | ||
|
|
b39653cf96 | ||
|
|
39f8b6cf16 | ||
|
|
d0c11b30b0 | ||
|
|
86e2d5ba84 | ||
|
|
0d82dff9c5 | ||
|
|
8bbc956ff1 | ||
|
|
22019fddeb | ||
|
|
6fb7e91343 | ||
|
|
ba58ae0bf2 |
@@ -634,10 +634,12 @@ Alternatively, you can also use <a href="https://github.com/mlfoundations/open_c
|
||||
$ pip install open-clip-torch
|
||||
```
|
||||
|
||||
Ex. using the <a href="https://laion.ai/blog/large-openclip/">SOTA Open Clip</a> model trained by <a href="https://github.com/rom1504">Romain</a>
|
||||
|
||||
```python
|
||||
from dalle2_pytorch import OpenClipAdapter
|
||||
|
||||
clip = OpenClipAdapter()
|
||||
clip = OpenClipAdapter('ViT-H/14')
|
||||
```
|
||||
|
||||
Now you'll just have to worry about training the Prior and the Decoder!
|
||||
@@ -1066,7 +1068,7 @@ dataloader = create_image_embedding_dataloader(
|
||||
)
|
||||
for img, emb in dataloader:
|
||||
print(img.shape) # torch.Size([32, 3, 256, 256])
|
||||
print(emb.shape) # torch.Size([32, 512])
|
||||
print(emb["img"].shape) # torch.Size([32, 512])
|
||||
# Train decoder only as shown above
|
||||
|
||||
# Or create a dataset without a loader so you can configure it manually
|
||||
@@ -1126,6 +1128,7 @@ For detailed information on training the diffusion prior, please refer to the [d
|
||||
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
|
||||
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
|
||||
- [ ] add simple outpainting, text-guided 2x size the image for starters
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -100,6 +100,9 @@ def eval_decorator(fn):
|
||||
return out
|
||||
return inner
|
||||
|
||||
def is_float_dtype(dtype):
|
||||
return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
|
||||
|
||||
def is_list_str(x):
|
||||
if not isinstance(x, (list, tuple)):
|
||||
return False
|
||||
@@ -314,7 +317,10 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
self.eos_id = 49407 # for handling 0 being also '!'
|
||||
|
||||
text_attention_final = self.find_layer('ln_final')
|
||||
|
||||
self.dim_latent_ = text_attention_final.weight.shape[0]
|
||||
self.handle = text_attention_final.register_forward_hook(self._hook)
|
||||
|
||||
self.clip_normalize = preprocess.transforms[-1]
|
||||
self.cleared = False
|
||||
|
||||
@@ -333,7 +339,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
|
||||
@property
|
||||
def dim_latent(self):
|
||||
return 512
|
||||
return self.dim_latent_
|
||||
|
||||
@property
|
||||
def image_size(self):
|
||||
@@ -406,7 +412,10 @@ class OpenClipAdapter(BaseClipAdapter):
|
||||
|
||||
@property
|
||||
def image_size(self):
|
||||
return self.clip.visual.image_size
|
||||
image_size = self.clip.visual.image_size
|
||||
if isinstance(image_size, tuple):
|
||||
return max(image_size)
|
||||
return image_size
|
||||
|
||||
@property
|
||||
def image_channels(self):
|
||||
@@ -962,6 +971,8 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
Rearrange('b (n d) -> b n d', n = num_text_embeds)
|
||||
)
|
||||
|
||||
self.continuous_embedded_time = not exists(num_timesteps)
|
||||
|
||||
self.to_time_embeds = nn.Sequential(
|
||||
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
Rearrange('b (n d) -> b n d', n = num_time_embeds)
|
||||
@@ -1070,7 +1081,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
|
||||
null_text_embeds = self.null_text_embeds.to(text_embed.dtype)
|
||||
|
||||
text_embeds = torch.where(
|
||||
text_embed = torch.where(
|
||||
text_keep_mask,
|
||||
text_embed,
|
||||
null_text_embeds
|
||||
@@ -1089,6 +1100,9 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||
# but let's just do it right
|
||||
|
||||
if self.continuous_embedded_time:
|
||||
diffusion_timesteps = diffusion_timesteps.type(dtype)
|
||||
|
||||
time_embed = self.to_time_embeds(diffusion_timesteps)
|
||||
|
||||
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
|
||||
@@ -1166,6 +1180,10 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
self.net = net
|
||||
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
|
||||
|
||||
assert net.dim == self.image_embed_dim, f'your diffusion prior network has a dimension of {net.dim}, but you set your image embedding dimension (keyword image_embed_dim) on DiffusionPrior to {self.image_embed_dim}'
|
||||
assert not exists(clip) or clip.dim_latent == self.image_embed_dim, f'you passed in a CLIP to the diffusion prior with latent dimensions of {clip.dim_latent}, but your image embedding dimension (keyword image_embed_dim) for the DiffusionPrior was set to {self.image_embed_dim}'
|
||||
|
||||
self.channels = default(image_channels, lambda: clip.image_channels)
|
||||
|
||||
self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
|
||||
@@ -1255,7 +1273,7 @@ class DiffusionPrior(nn.Module):
|
||||
def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
|
||||
batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
|
||||
|
||||
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
|
||||
times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]
|
||||
|
||||
times = list(reversed(times.int().tolist()))
|
||||
time_pairs = list(zip(times[:-1], times[1:]))
|
||||
@@ -1277,12 +1295,14 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
|
||||
|
||||
# derive x0
|
||||
|
||||
if self.predict_x_start:
|
||||
x_start = pred
|
||||
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
|
||||
else:
|
||||
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
|
||||
pred_noise = pred
|
||||
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise)
|
||||
|
||||
# clip x0 before maybe predicting noise
|
||||
|
||||
if not self.predict_x_start:
|
||||
x_start.clamp_(-1., 1.)
|
||||
@@ -1290,6 +1310,17 @@ class DiffusionPrior(nn.Module):
|
||||
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||
x_start = self.l2norm_clamp_embed(x_start)
|
||||
|
||||
# predict noise
|
||||
|
||||
if self.predict_x_start:
|
||||
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
|
||||
else:
|
||||
pred_noise = pred
|
||||
|
||||
if time_next < 0:
|
||||
image_embed = x_start
|
||||
continue
|
||||
|
||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
|
||||
@@ -1409,7 +1440,7 @@ class DiffusionPrior(nn.Module):
|
||||
**kwargs
|
||||
):
|
||||
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
|
||||
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
|
||||
assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
|
||||
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
||||
|
||||
if exists(image):
|
||||
@@ -1515,6 +1546,8 @@ class SinusoidalPosEmb(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
dtype, device = x.dtype, x.device
|
||||
assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
|
||||
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
|
||||
@@ -2841,12 +2874,13 @@ class Decoder(nn.Module):
|
||||
inpaint_mask = None,
|
||||
inpaint_resample_times = 5
|
||||
):
|
||||
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
|
||||
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod, self.ddim_sampling_eta
|
||||
|
||||
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
|
||||
|
||||
times = list(reversed(times.int().tolist()))
|
||||
time_pairs = list(zip(times[:-1], times[1:]))
|
||||
time_pairs = list(filter(lambda t: t[0] > t[1], time_pairs))
|
||||
|
||||
is_inpaint = exists(inpaint_image)
|
||||
resample_times = inpaint_resample_times if is_inpaint else 1
|
||||
@@ -2888,16 +2922,25 @@ class Decoder(nn.Module):
|
||||
|
||||
pred, _ = self.parse_unet_output(learned_variance, unet_output)
|
||||
|
||||
# predict x0
|
||||
|
||||
if predict_x_start:
|
||||
x_start = pred
|
||||
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
|
||||
else:
|
||||
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
|
||||
pred_noise = pred
|
||||
|
||||
# maybe clip x0
|
||||
|
||||
if clip_denoised:
|
||||
x_start = self.dynamic_threshold(x_start)
|
||||
|
||||
# predict noise
|
||||
|
||||
if predict_x_start:
|
||||
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
|
||||
else:
|
||||
pred_noise = pred
|
||||
|
||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||
noise = torch.randn_like(img) if not is_last_timestep else 0.
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '1.10.0'
|
||||
__version__ = '1.10.8'
|
||||
|
||||
@@ -156,7 +156,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
|
||||
if text_embeddings[0] is None:
|
||||
# Generate text embeddings from text
|
||||
assert clip is not None, "clip is None, but text_embeddings is None"
|
||||
tokenized_texts = tokenize(txts, truncate=True)
|
||||
tokenized_texts = tokenize(txts, truncate=True).to(device=device)
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
sample_params["text_encodings"] = text_encodings
|
||||
else:
|
||||
@@ -229,8 +229,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=Non
|
||||
metrics["KID_std"] = kid_std.item()
|
||||
if exists(LPIPS):
|
||||
# Convert from [0, 1] to [-1, 1]
|
||||
renorm_real_images = real_images.mul(2).sub(1)
|
||||
renorm_generated_images = generated_images.mul(2).sub(1)
|
||||
renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
|
||||
renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
|
||||
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
|
||||
lpips.to(device=device)
|
||||
lpips.update(renorm_real_images, renorm_generated_images)
|
||||
@@ -480,7 +480,7 @@ def train(
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
assert clip is not None
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
|
||||
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
forward_params['text_encodings'] = text_encodings
|
||||
|
||||
Reference in New Issue
Block a user