Compare commits


8 Commits

Author  SHA1  Message  Date

Phil Wang  cc58f75474  bump to newer package of clip-anytorch that allows for text encodings < maximum context length  2023-03-04 09:37:25 -08:00
Phil Wang  3b2cf7b0bc  fix for self conditioning in diffusion prior network https://github.com/lucidrains/DALLE2-pytorch/issues/273  2023-02-11 17:18:40 -08:00
Phil Wang  984d62a373  default ddim sampling eta to 0  2022-12-23 13:23:09 -08:00
Phil Wang  683dd98b96  extra insurance in case eos id is not there  2022-12-15 10:54:21 -08:00
Phil Wang  067ac323da  address https://github.com/lucidrains/DALLE2-pytorch/issues/266  2022-11-23 08:41:25 -08:00
zion  91c8d1ca13  bug fix cosine annealing optimizer in prior trainer (#262)  2022-11-11 12:15:13 -08:00
zion  08238a7200  depend on open-clip-torch (#261)  2022-11-07 16:19:08 -08:00
    fix the previous commit which assumes open_clip is installed
zion  7166ad6711  add open clip to train_config (#260)  2022-11-07 15:44:36 -08:00
    add the ability to use open_clip in the train configs (useful for the new SOTA h/14 model)
5 changed files with 15 additions and 6 deletions

dalle2_pytorch/dalle2_pytorch.py

@@ -360,6 +360,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         is_eos_id = (text == self.eos_id)
         text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
         text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+        text_mask = text_mask & (text != 0)
         assert not self.cleared
         text_embed = self.clip.encode_text(text)
@@ -434,6 +435,7 @@ class OpenClipAdapter(BaseClipAdapter):
         is_eos_id = (text == self.eos_id)
         text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
         text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+        text_mask = text_mask & (text != 0)
         assert not self.cleared
         text_embed = self.clip.encode_text(text)
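For context on the two masking hunks above: the mask keeps every token up to and including the end-of-text id, and the newly added `& (text != 0)` guard also drops padding when no EOS token is present (the "extra insurance" from commit 683dd98b96). A minimal sketch of that behaviour on toy token ids; the eos id 49407 (the CLIP BPE end-of-text id) and the example sequences are assumptions for illustration, not taken from this diff:

import torch
import torch.nn.functional as F

eos_id = 49407  # assumed: end-of-text id of the CLIP BPE tokenizer

text = torch.tensor([
    [320, 1125, 2368, eos_id, 0, 0],  # normal case: EOS present, then padding
    [320, 1125, 2368, 7099, 0, 0],    # "insurance" case: EOS id missing, only padding follows
])

is_eos_id = (text == eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0           # True strictly before EOS
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)   # shift right so EOS itself is kept
text_mask = text_mask & (text != 0)                                  # new guard: mask padding even without EOS

print(text_mask)
# row 0: [True, True, True, True, False, False]  -> real tokens plus EOS kept, padding masked
# row 1 would be all True without the guard; with it: [True, True, True, True, False, False]
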
@@ -1122,7 +1124,7 @@ class DiffusionPriorNetwork(nn.Module):
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
         if self.self_cond:
-            learned_queries = torch.cat((image_embed, self_cond), dim = -2)
+            learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
         tokens = torch.cat((
             text_encodings,
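The self-conditioning hunk replaces a line that discarded the learned query and concatenated the wrong pair of tensors; with the fix, the previous x-start estimate is prepended to the learned queries, so both tokens end up in the token sequence fed to the prior network. A rough sketch of the intended shapes, with made-up batch size and embedding dimension:

import torch

batch, dim = 4, 512  # assumed sizes for illustration

learned_query = torch.randn(dim)
learned_queries = learned_query.expand(batch, 1, dim)  # (b, 1, d), as repeat(...) produces
self_cond = torch.zeros(batch, 1, dim)                  # previous x0 estimate; zeros on the first pass

# fixed behaviour: prepend the self-conditioning token, keep the learned query
learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
print(learned_queries.shape)  # torch.Size([4, 2, 512])
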
@@ -1320,7 +1322,7 @@ class DiffusionPrior(nn.Module):
         elif self.predict_x_start:
             x_start = pred
         else:
-            x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise)
+            x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
         # clip x0 before maybe predicting noise
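The `noise = pred` change passes the network's actual output into the scheduler in the fallback noise-prediction branch; the old line referenced `pred_noise` rather than the prediction computed in that branch. The helper then recovers x0 via the standard DDPM identity. A self-contained sketch of that identity (not the repo's scheduler code; the alpha-bar value is an assumed example):

import torch

def predict_start_from_noise(x_t, alpha_bar_t, noise):
    # standard DDPM identity: x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
    return (x_t - (1 - alpha_bar_t).sqrt() * noise) / alpha_bar_t.sqrt()

# round-trip check with an assumed alpha_bar of 0.7
x0 = torch.randn(2, 768)
eps = torch.randn_like(x0)
alpha_bar = torch.tensor(0.7)
x_t = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps
assert torch.allclose(predict_start_from_noise(x_t, alpha_bar, eps), x0, atol = 1e-5)
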
@@ -2494,7 +2496,7 @@ class Decoder(nn.Module):
         dynamic_thres_percentile = 0.95,
         p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
         p2_loss_weight_k = 1,
-        ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
+        ddim_sampling_eta = 0. # can be set to 0. for deterministic sampling afaict
     ):
         super().__init__()
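On the ddim_sampling_eta default: in DDIM, eta scales the per-step noise, so eta = 0 collapses sampling to a deterministic trajectory while eta = 1 recovers DDPM-like stochasticity. A hedged sketch of the usual sigma schedule from the DDIM paper (illustrative only, not code from this repo; the alpha-bar values are assumed examples):

import torch

def ddim_sigma(alpha_bar_prev, alpha_bar, eta):
    # sigma_t = eta * sqrt((1 - a_prev) / (1 - a)) * sqrt(1 - a / a_prev); eta = 0 gives sigma = 0
    return eta * ((1 - alpha_bar_prev) / (1 - alpha_bar)).sqrt() * (1 - alpha_bar / alpha_bar_prev).sqrt()

alpha_bar, alpha_bar_prev = torch.tensor(0.5), torch.tensor(0.7)
print(ddim_sigma(alpha_bar_prev, alpha_bar, eta = 0.))  # tensor(0.) -> deterministic step
print(ddim_sigma(alpha_bar_prev, alpha_bar, eta = 1.))  # > 0 -> stochastic step
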

dalle2_pytorch/train_configs.py

@@ -4,11 +4,13 @@ from pydantic import BaseModel, validator, root_validator
 from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
 from x_clip import CLIP as XCLIP
+from open_clip import list_pretrained
 from coca_pytorch import CoCa
 from dalle2_pytorch.dalle2_pytorch import (
     CoCaAdapter,
     OpenAIClipAdapter,
+    OpenClipAdapter,
     Unet,
     Decoder,
     DiffusionPrior,
@@ -117,6 +119,10 @@ class AdapterConfig(BaseModel):
     def create(self):
         if self.make == "openai":
             return OpenAIClipAdapter(self.model)
+        elif self.make == "open_clip":
+            pretrained = dict(list_pretrained())
+            checkpoint = pretrained[self.model]
+            return OpenClipAdapter(name=self.model, pretrained=checkpoint)
         elif self.make == "x-clip":
             return XClipAdapter(XCLIP(**self.base_model_kwargs))
         elif self.make == "coca":

dalle2_pytorch/trainer.py

@@ -236,7 +236,7 @@ class DiffusionPriorTrainer(nn.Module):
         )
         if exists(cosine_decay_max_steps):
-            self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
+            self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
         else:
             self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
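The trainer fix points CosineAnnealingLR at self.optimizer, the optimizer actually stored on the trainer, rather than a bare name that is not defined in that scope. A minimal sketch of the corrected pattern; the parameters, learning rate, and step count are assumed values:

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# stand-in for the optimizer the trainer stores on itself
params = [torch.nn.Parameter(torch.zeros(10))]
optimizer = AdamW(params, lr = 1e-4)

cosine_decay_max_steps = 1_000  # assumed value
scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)

for _ in range(3):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # lr decays along a cosine toward 0 over T_max steps
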

dalle2_pytorch/version.py

@@ -1 +1 @@
-__version__ = '1.11.1'
+__version__ = '1.12.2'

setup.py

@@ -26,7 +26,8 @@ setup(
   install_requires=[
     'accelerate',
     'click',
-    'clip-anytorch>=2.4.0',
+    'open-clip-torch>=2.0.0,<3.0.0',
+    'clip-anytorch>=2.5.2',
     'coca-pytorch>=0.0.5',
     'ema-pytorch>=0.0.7',
     'einops>=0.4',