Compare commits

...

7 Commits

Author SHA1 Message Date
Phil Wang
41fabf2922 fix a dtype conversion issue for the diffusion timesteps in the diffusion prior, thanks to @JiaHeng-DLUT 2022-10-19 09:26:06 -07:00
Heng Jia
5975e8222b Fix assert message (#253) 2022-10-18 08:50:59 -07:00
Phil Wang
c18c080128 fix for use with larger openai clip models by extracting dimension of last layernorm in clip 2022-09-29 09:09:47 -07:00
Phil Wang
b39653cf96 fix readme dataloader example 2022-09-20 08:39:52 -07:00
Phil Wang
39f8b6cf16 show example of using SOTA open sourced open clip 2022-09-19 10:45:20 -07:00
Phil Wang
d0c11b30b0 handle open clip adapter image size being a tuple 2022-09-19 10:27:14 -07:00
zion
86e2d5ba84 Minor Decoder Train Script Fixes (#242)
* ensure tokenized text is on proper device
* fix lpips mage distribution
2022-09-15 17:21:48 -07:00
4 changed files with 28 additions and 10 deletions

View File

@@ -634,10 +634,12 @@ Alternatively, you can also use <a href="https://github.com/mlfoundations/open_c
$ pip install open-clip-torch
```
Ex. using the <a href="https://laion.ai/blog/large-openclip/">SOTA Open Clip</a> model trained by <a href="https://github.com/rom1504">Romain</a>
```python
from dalle2_pytorch import OpenClipAdapter
clip = OpenClipAdapter()
clip = OpenClipAdapter('ViT-H/14')
```
Now you'll just have to worry about training the Prior and the Decoder!
@@ -1066,7 +1068,7 @@ dataloader = create_image_embedding_dataloader(
)
for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb.shape) # torch.Size([32, 512])
print(emb["img"].shape) # torch.Size([32, 512])
# Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually

View File

@@ -100,6 +100,9 @@ def eval_decorator(fn):
return out
return inner
def is_float_dtype(dtype):
return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
def is_list_str(x):
if not isinstance(x, (list, tuple)):
return False
@@ -314,7 +317,10 @@ class OpenAIClipAdapter(BaseClipAdapter):
self.eos_id = 49407 # for handling 0 being also '!'
text_attention_final = self.find_layer('ln_final')
self.dim_latent_ = text_attention_final.weight.shape[0]
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
@@ -333,7 +339,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
@property
def dim_latent(self):
return 512
return self.dim_latent_
@property
def image_size(self):
@@ -406,7 +412,10 @@ class OpenClipAdapter(BaseClipAdapter):
@property
def image_size(self):
return self.clip.visual.image_size
image_size = self.clip.visual.image_size
if isinstance(image_size, tuple):
return max(image_size)
return image_size
@property
def image_channels(self):
@@ -962,6 +971,8 @@ class DiffusionPriorNetwork(nn.Module):
Rearrange('b (n d) -> b n d', n = num_text_embeds)
)
self.continuous_embedded_time = not exists(num_timesteps)
self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds)
@@ -1089,6 +1100,9 @@ class DiffusionPriorNetwork(nn.Module):
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if self.continuous_embedded_time:
diffusion_timesteps = diffusion_timesteps.type(dtype)
time_embed = self.to_time_embeds(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -1426,7 +1440,7 @@ class DiffusionPrior(nn.Module):
**kwargs
):
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
if exists(image):
@@ -1532,6 +1546,8 @@ class SinusoidalPosEmb(nn.Module):
def forward(self, x):
dtype, device = x.dtype, x.device
assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)

View File

@@ -1 +1 @@
__version__ = '1.10.5'
__version__ = '1.10.8'

View File

@@ -156,7 +156,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
if text_embeddings[0] is None:
# Generate text embeddings from text
assert clip is not None, "clip is None, but text_embeddings is None"
tokenized_texts = tokenize(txts, truncate=True)
tokenized_texts = tokenize(txts, truncate=True).to(device=device)
text_embed, text_encodings = clip.embed_text(tokenized_texts)
sample_params["text_encodings"] = text_encodings
else:
@@ -229,8 +229,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=Non
metrics["KID_std"] = kid_std.item()
if exists(LPIPS):
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images)
@@ -480,7 +480,7 @@ def train(
else:
# Then we need to pass the text instead
assert clip is not None
tokenized_texts = tokenize(txt, truncate=True)
tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
text_embed, text_encodings = clip.embed_text(tokenized_texts)
forward_params['text_encodings'] = text_encodings