Compare commits

...

10 Commits

Author SHA1 Message Date
Phil Wang
c18c080128 fix for use with larger openai clip models by extracting dimension of last layernorm in clip 2022-09-29 09:09:47 -07:00
Phil Wang
b39653cf96 fix readme dataloader example 2022-09-20 08:39:52 -07:00
Phil Wang
39f8b6cf16 show example of using SOTA open sourced open clip 2022-09-19 10:45:20 -07:00
Phil Wang
d0c11b30b0 handle open clip adapter image size being a tuple 2022-09-19 10:27:14 -07:00
zion
86e2d5ba84 Minor Decoder Train Script Fixes (#242)
* ensure tokenized text is on proper device
* fix lpips mage distribution
2022-09-15 17:21:48 -07:00
Phil Wang
0d82dff9c5 in ddim, noise should be predicted after x0 is maybe clipped, thanks to @lukovnikov for pointing this out in another repository 2022-09-01 09:40:47 -07:00
Phil Wang
8bbc956ff1 fix bug with misnamed variable in diffusion prior network 2022-08-31 17:19:05 -07:00
Phil Wang
22019fddeb todo 2022-08-31 13:36:05 -07:00
Phil Wang
6fb7e91343 fix ddim to use alpha_cumprod 2022-08-31 07:40:46 -07:00
Phil Wang
ba58ae0bf2 add two asserts to diffusion prior to ensure matching image embedding dimensions for clip, diffusion prior network, and what was set on diffusion prior 2022-08-28 10:11:37 -07:00
4 changed files with 53 additions and 17 deletions

View File

@@ -634,10 +634,12 @@ Alternatively, you can also use <a href="https://github.com/mlfoundations/open_c
$ pip install open-clip-torch
```
Ex. using the <a href="https://laion.ai/blog/large-openclip/">SOTA Open Clip</a> model trained by <a href="https://github.com/rom1504">Romain</a>
```python
from dalle2_pytorch import OpenClipAdapter
clip = OpenClipAdapter()
clip = OpenClipAdapter('ViT-H/14')
```
Now you'll just have to worry about training the Prior and the Decoder!
@@ -1066,7 +1068,7 @@ dataloader = create_image_embedding_dataloader(
)
for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb.shape) # torch.Size([32, 512])
print(emb["img"].shape) # torch.Size([32, 512])
# Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually
@@ -1126,6 +1128,7 @@ For detailed information on training the diffusion prior, please refer to the [d
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
- [ ] add simple outpainting, text-guided 2x size the image for starters
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
## Citations

View File

@@ -314,7 +314,10 @@ class OpenAIClipAdapter(BaseClipAdapter):
self.eos_id = 49407 # for handling 0 being also '!'
text_attention_final = self.find_layer('ln_final')
self.dim_latent_ = text_attention_final.weight.shape[0]
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
@@ -333,7 +336,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
@property
def dim_latent(self):
return 512
return self.dim_latent_
@property
def image_size(self):
@@ -406,7 +409,10 @@ class OpenClipAdapter(BaseClipAdapter):
@property
def image_size(self):
return self.clip.visual.image_size
image_size = self.clip.visual.image_size
if isinstance(image_size, tuple):
return max(image_size)
return image_size
@property
def image_channels(self):
@@ -1070,7 +1076,7 @@ class DiffusionPriorNetwork(nn.Module):
null_text_embeds = self.null_text_embeds.to(text_embed.dtype)
text_embeds = torch.where(
text_embed = torch.where(
text_keep_mask,
text_embed,
null_text_embeds
@@ -1166,6 +1172,10 @@ class DiffusionPrior(nn.Module):
self.net = net
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
assert net.dim == self.image_embed_dim, f'your diffusion prior network has a dimension of {net.dim}, but you set your image embedding dimension (keyword image_embed_dim) on DiffusionPrior to {self.image_embed_dim}'
assert not exists(clip) or clip.dim_latent == self.image_embed_dim, f'you passed in a CLIP to the diffusion prior with latent dimensions of {clip.dim_latent}, but your image embedding dimension (keyword image_embed_dim) for the DiffusionPrior was set to {self.image_embed_dim}'
self.channels = default(image_channels, lambda: clip.image_channels)
self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
@@ -1255,7 +1265,7 @@ class DiffusionPrior(nn.Module):
def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
@@ -1277,12 +1287,14 @@ class DiffusionPrior(nn.Module):
pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
# derive x0
if self.predict_x_start:
x_start = pred
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
else:
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
pred_noise = pred
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise)
# clip x0 before maybe predicting noise
if not self.predict_x_start:
x_start.clamp_(-1., 1.)
@@ -1290,6 +1302,17 @@ class DiffusionPrior(nn.Module):
if self.predict_x_start and self.sampling_clamp_l2norm:
x_start = self.l2norm_clamp_embed(x_start)
# predict noise
if self.predict_x_start:
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
else:
pred_noise = pred
if time_next < 0:
image_embed = x_start
continue
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
@@ -2841,12 +2864,13 @@ class Decoder(nn.Module):
inpaint_mask = None,
inpaint_resample_times = 5
):
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod, self.ddim_sampling_eta
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
time_pairs = list(filter(lambda t: t[0] > t[1], time_pairs))
is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1
@@ -2888,16 +2912,25 @@ class Decoder(nn.Module):
pred, _ = self.parse_unet_output(learned_variance, unet_output)
# predict x0
if predict_x_start:
x_start = pred
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
pred_noise = pred
# maybe clip x0
if clip_denoised:
x_start = self.dynamic_threshold(x_start)
# predict noise
if predict_x_start:
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
pred_noise = pred
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if not is_last_timestep else 0.

View File

@@ -1 +1 @@
__version__ = '1.10.0'
__version__ = '1.10.7'

View File

@@ -156,7 +156,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
if text_embeddings[0] is None:
# Generate text embeddings from text
assert clip is not None, "clip is None, but text_embeddings is None"
tokenized_texts = tokenize(txts, truncate=True)
tokenized_texts = tokenize(txts, truncate=True).to(device=device)
text_embed, text_encodings = clip.embed_text(tokenized_texts)
sample_params["text_encodings"] = text_encodings
else:
@@ -229,8 +229,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=Non
metrics["KID_std"] = kid_std.item()
if exists(LPIPS):
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images)
@@ -480,7 +480,7 @@ def train(
else:
# Then we need to pass the text instead
assert clip is not None
tokenized_texts = tokenize(txt, truncate=True)
tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
text_embed, text_encodings = clip.embed_text(tokenized_texts)
forward_params['text_encodings'] = text_encodings