Compare commits

..

3 Commits

Author SHA1 Message Date
lucidrains
1e173f4c66 more fixes to config 2023-10-18 20:27:32 -07:00
lucidrains
410a6144e1 new einops is torch compile friendly 2023-10-18 15:45:09 -07:00
lucidrains
c6c3882dc1 fix all optional types in train config 2023-10-07 11:34:34 -07:00
5 changed files with 23 additions and 30 deletions

View File

@@ -9,7 +9,7 @@
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 16,
"attn_heads": 4,
"self_attn": [false, true, true, true]
"self_attn": [false, true, true, true]
}
],
"clip": {

View File

@@ -1,10 +1,3 @@
import torch
from packaging import version
if version.parse(torch.__version__) >= version.parse('2.0.0'):
from einops._torch_specific import allow_ops_in_compiled_graph
allow_ops_in_compiled_graph()
from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter

View File

@@ -115,7 +115,7 @@ class TrackerConfig(BaseModel):
class AdapterConfig(BaseModel):
make: str = "openai"
model: str = "ViT-L/14"
base_model_kwargs: Dict[str, Any] = None
base_model_kwargs: Optional[Dict[str, Any]] = None
def create(self):
if self.make == "openai":
@@ -134,8 +134,8 @@ class AdapterConfig(BaseModel):
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
max_text_len: int = None
num_timesteps: int = None
max_text_len: Optional[int] = None
num_timesteps: Optional[int] = None
num_time_embeds: int = 1
num_image_embeds: int = 1
num_text_embeds: int = 1
@@ -158,7 +158,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
return DiffusionPriorNetwork(**kwargs)
class DiffusionPriorConfig(BaseModel):
clip: AdapterConfig = None
clip: Optional[AdapterConfig] = None
net: DiffusionPriorNetworkConfig
image_embed_dim: int
image_size: int
@@ -195,7 +195,7 @@ class DiffusionPriorTrainConfig(BaseModel):
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
warmup_steps: int = None # number of warmup steps
warmup_steps: Optional[int] = None # number of warmup steps
save_every_seconds: int = 3600 # how often to save
eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
best_validation_loss: float = 1e9 # the current best validation loss observed
@@ -228,12 +228,12 @@ class TrainDiffusionPriorConfig(BaseModel):
class UnetConfig(BaseModel):
dim: int
dim_mults: ListOrTuple[int]
image_embed_dim: int = None
text_embed_dim: int = None
cond_on_text_encodings: bool = None
cond_dim: int = None
image_embed_dim: Optional[int] = None
text_embed_dim: Optional[int] = None
cond_on_text_encodings: Optional[bool] = None
cond_dim: Optional[int] = None
channels: int = 3
self_attn: ListOrTuple[int]
self_attn: ListOrTuple[bool]
attn_dim_head: int = 32
attn_heads: int = 16
init_cross_embed: bool = True
@@ -243,14 +243,14 @@ class UnetConfig(BaseModel):
class DecoderConfig(BaseModel):
unets: ListOrTuple[UnetConfig]
image_size: int = None
image_size: Optional[int] = None
image_sizes: ListOrTuple[int] = None
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
clip: Optional[AdapterConfig] = None # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple[str] = None # None means all cosine
beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine
learned_variance: SingularOrIterable[bool] = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
@@ -320,20 +320,20 @@ class DecoderTrainConfig(BaseModel):
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
cond_scale: Union[float, List[float]] = 1.0
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: Optional[int] = None # Same as above but for validation.
save_immediately: bool = False
use_ema: bool = True
ema_beta: float = 0.999
amp: bool = False
unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
FID: Dict[str, Any] = None
IS: Dict[str, Any] = None
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
FID: Optional[Dict[str, Any]] = None
IS: Optional[Dict[str, Any]] = None
KID: Optional[Dict[str, Any]] = None
LPIPS: Optional[Dict[str, Any]] = None
class TrainDecoderConfig(BaseModel):
decoder: DecoderConfig

View File

@@ -1 +1 @@
__version__ = '1.15.2'
__version__ = '1.15.5'

View File

@@ -30,7 +30,7 @@ setup(
'clip-anytorch>=2.5.2',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',
'einops>=0.6.1',
'einops>=0.7.0',
'embedding-reader',
'kornia>=0.5.4',
'numpy',