Compare commits


1 commit
main ... 1.4.1

10 changed files with 216 additions and 711 deletions

View File

@@ -49,7 +49,6 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice - <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship - <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library - <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
- <a href="https://github.com/arogozhnikov">Alex</a> for <a href="https://github.com/arogozhnikov/einops">einops</a>, indispensable tool for tensor manipulation
... and many others. Thank you! 🙏 ... and many others. Thank you! 🙏
@@ -628,20 +627,6 @@ images = dalle2(
 # save your image (in this example, of size 256x256)
 ```
-Alternatively, you can also use <a href="https://github.com/mlfoundations/open_clip">Open Clip</a>
-```bash
-$ pip install open-clip-torch
-```
-Ex. using the <a href="https://laion.ai/blog/large-openclip/">SOTA Open Clip</a> model trained by <a href="https://github.com/rom1504">Romain</a>
-```python
-from dalle2_pytorch import OpenClipAdapter
-clip = OpenClipAdapter('ViT-H/14')
-```
 Now you'll just have to worry about training the Prior and the Decoder!
 ## Inpainting
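For reference, the adapter constructed in the removed snippet exposes the same `embed_image` interface used by the training script later in this comparison. A rough sketch of how it would be exercised (placeholder batch; inputs are assumed to already be at the CLIP resolution, 224x224 for ViT-H/14):

```python
# Rough sketch only: exercises the OpenClipAdapter shown in the removed README snippet.
# Assumes open-clip-torch is installed and images are already sized for the CLIP model.
import torch
from dalle2_pytorch import OpenClipAdapter

clip = OpenClipAdapter('ViT-H/14')

images = torch.rand(4, 3, 224, 224)              # placeholder batch of images in [0, 1]
image_embed, image_encodings = clip.embed_image(images)
print(image_embed.shape)                         # (4, 1024) for ViT-H/14
```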
@@ -1068,7 +1053,7 @@ dataloader = create_image_embedding_dataloader(
 )
 for img, emb in dataloader:
 print(img.shape) # torch.Size([32, 3, 256, 256])
-print(emb["img"].shape) # torch.Size([32, 512])
+print(emb.shape) # torch.Size([32, 512])
 # Train decoder only as shown above
 # Or create a dataset without a loader so you can configure it manually
@@ -1128,7 +1113,6 @@ For detailed information on training the diffusion prior, please refer to the [d
 - [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
 - [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
 - [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
-- [ ] add simple outpainting, text-guided 2x size the image for starters
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
 ## Citations
@@ -1257,55 +1241,4 @@ For detailed information on training the diffusion prior, please refer to the [d
 }
 ```
-```bibtex
-@misc{chen2022analog,
-title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
-author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
-year = {2022},
-eprint = {2208.04202},
-archivePrefix = {arXiv},
-primaryClass = {cs.CV}
-}
-```
-```bibtex
-@article{Qiao2019WeightS,
-title = {Weight Standardization},
-author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
-journal = {ArXiv},
-year = {2019},
-volume = {abs/1903.10520}
-}
-```
-```bibtex
-@inproceedings{rogozhnikov2022einops,
-title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
-author = {Alex Rogozhnikov},
-booktitle = {International Conference on Learning Representations},
-year = {2022},
-url = {https://openreview.net/forum?id=oapKSVM2bcj}
-}
-```
-```bibtex
-@article{Sunkara2022NoMS,
-title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
-author = {Raja Sunkara and Tie Luo},
-journal = {ArXiv},
-year = {2022},
-volume = {abs/2208.03641}
-}
-```
-```bibtex
-@article{Salimans2022ProgressiveDF,
-title = {Progressive Distillation for Fast Sampling of Diffusion Models},
-author = {Tim Salimans and Jonathan Ho},
-journal = {ArXiv},
-year = {2022},
-volume = {abs/2202.00512}
-}
-```
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -9,7 +9,7 @@
"dim_mults": [1, 2, 4, 8], "dim_mults": [1, 2, 4, 8],
"attn_dim_head": 16, "attn_dim_head": 16,
"attn_heads": 4, "attn_heads": 4,
"self_attn": [false, true, true, true] "self_attn": [false, true, true, true]
} }
], ],
"clip": { "clip": {

View File

@@ -1,6 +1,6 @@
 from dalle2_pytorch.version import __version__
 from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
-from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter
+from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
 from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
 from dalle2_pytorch.vqgan_vae import VQGanVAE

File diff suppressed because it is too large

View File

@@ -1,16 +1,14 @@
 import json
 from torchvision import transforms as T
-from pydantic import BaseModel, validator, model_validator
+from pydantic import BaseModel, validator, root_validator
 from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
 from x_clip import CLIP as XCLIP
-from open_clip import list_pretrained
 from coca_pytorch import CoCa
 from dalle2_pytorch.dalle2_pytorch import (
 CoCaAdapter,
 OpenAIClipAdapter,
-OpenClipAdapter,
 Unet,
 Decoder,
 DiffusionPrior,
@@ -38,12 +36,12 @@ class TrainSplitConfig(BaseModel):
 val: float = 0.15
 test: float = 0.1
-@model_validator(mode = 'after')
-def validate_all(self, m):
-actual_sum = sum([*dict(self).values()])
+@root_validator
+def validate_all(cls, fields):
+actual_sum = sum([*fields.values()])
 if actual_sum != 1.:
-raise ValueError(f'{dict(self).keys()} must sum to 1.0. Found: {actual_sum}')
-return self
+raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
+return fields
 class TrackerLogConfig(BaseModel):
 log_type: str = 'console'
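Most of the changes in this file amount to a pydantic major-version difference: main targets pydantic v2 (`model_validator`, explicit `Optional[...] = None` annotations), while 1.4.1 targets the v1 API (`root_validator`, implicit optionals). For reference, a minimal standalone sketch of the v2-style validator above; the `train` default is an assumption, as it is not visible in this hunk:

```python
# Minimal sketch of the pydantic v2 form used on main (requires pydantic>=2).
from pydantic import BaseModel, model_validator

class TrainSplitConfig(BaseModel):
    train: float = 0.75  # assumed default, not shown in the hunk above
    val: float = 0.15
    test: float = 0.1

    @model_validator(mode = 'after')
    def validate_all(self):
        actual_sum = sum(dict(self).values())
        if actual_sum != 1.:
            raise ValueError(f'{dict(self).keys()} must sum to 1.0. Found: {actual_sum}')
        return self

TrainSplitConfig()             # passes when the three fields sum to 1.0
# TrainSplitConfig(val = 0.5)  # would raise ValueError
```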
@@ -59,7 +57,6 @@ class TrackerLogConfig(BaseModel):
 kwargs = self.dict()
 return create_logger(self.log_type, data_path, **kwargs)
 class TrackerLoadConfig(BaseModel):
 load_from: Optional[str] = None
 only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming
@@ -90,7 +87,7 @@ class TrackerConfig(BaseModel):
 data_path: str = '.tracker_data'
 overwrite_data_path: bool = False
 log: TrackerLogConfig
-load: Optional[TrackerLoadConfig] = None
+load: Optional[TrackerLoadConfig]
 save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
 def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
@@ -115,15 +112,11 @@ class TrackerConfig(BaseModel):
 class AdapterConfig(BaseModel):
 make: str = "openai"
 model: str = "ViT-L/14"
-base_model_kwargs: Optional[Dict[str, Any]] = None
+base_model_kwargs: Dict[str, Any] = None
 def create(self):
 if self.make == "openai":
 return OpenAIClipAdapter(self.model)
-elif self.make == "open_clip":
-pretrained = dict(list_pretrained())
-checkpoint = pretrained[self.model]
-return OpenClipAdapter(name=self.model, pretrained=checkpoint)
 elif self.make == "x-clip":
 return XClipAdapter(XCLIP(**self.base_model_kwargs))
 elif self.make == "coca":
@@ -134,8 +127,8 @@ class AdapterConfig(BaseModel):
 class DiffusionPriorNetworkConfig(BaseModel):
 dim: int
 depth: int
-max_text_len: Optional[int] = None
-num_timesteps: Optional[int] = None
+max_text_len: int = None
+num_timesteps: int = None
 num_time_embeds: int = 1
 num_image_embeds: int = 1
 num_text_embeds: int = 1
@@ -158,7 +151,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
 return DiffusionPriorNetwork(**kwargs)
 class DiffusionPriorConfig(BaseModel):
-clip: Optional[AdapterConfig] = None
+clip: AdapterConfig = None
 net: DiffusionPriorNetworkConfig
 image_embed_dim: int
 image_size: int
@@ -195,7 +188,7 @@ class DiffusionPriorTrainConfig(BaseModel):
 use_ema: bool = True
 ema_beta: float = 0.99
 amp: bool = False
-warmup_steps: Optional[int] = None # number of warmup steps
+warmup_steps: int = None # number of warmup steps
 save_every_seconds: int = 3600 # how often to save
 eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
 best_validation_loss: float = 1e9 # the current best valudation loss observed
@@ -228,12 +221,12 @@ class TrainDiffusionPriorConfig(BaseModel):
 class UnetConfig(BaseModel):
 dim: int
 dim_mults: ListOrTuple[int]
-image_embed_dim: Optional[int] = None
-text_embed_dim: Optional[int] = None
-cond_on_text_encodings: Optional[bool] = None
-cond_dim: Optional[int] = None
+image_embed_dim: int = None
+text_embed_dim: int = None
+cond_on_text_encodings: bool = None
+cond_dim: int = None
 channels: int = 3
-self_attn: SingularOrIterable[bool] = False
+self_attn: ListOrTuple[int]
 attn_dim_head: int = 32
 attn_heads: int = 16
 init_cross_embed: bool = True
@@ -243,14 +236,14 @@ class UnetConfig(BaseModel):
 class DecoderConfig(BaseModel):
 unets: ListOrTuple[UnetConfig]
-image_size: Optional[int] = None
+image_size: int = None
 image_sizes: ListOrTuple[int] = None
-clip: Optional[AdapterConfig] = None # The clip model to use if embeddings are not provided
+clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
 channels: int = 3
 timesteps: int = 1000
-sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
+sample_timesteps: Optional[SingularOrIterable[int]] = None
 loss_type: str = 'l2'
-beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine
+beta_schedule: ListOrTuple[str] = None # None means all cosine
 learned_variance: SingularOrIterable[bool] = True
 image_cond_drop_prob: float = 0.1
 text_cond_drop_prob: float = 0.5
@@ -278,9 +271,9 @@ class DecoderConfig(BaseModel):
extra = "allow" extra = "allow"
class DecoderDataConfig(BaseModel): class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images webdataset_base_url: str # path to a webdataset with jpg images
img_embeddings_url: Optional[str] = None # path to .npy files with embeddings img_embeddings_url: Optional[str] # path to .npy files with embeddings
text_embeddings_url: Optional[str] = None # path to .npy files with embeddings text_embeddings_url: Optional[str] # path to .npy files with embeddings
num_workers: int = 4 num_workers: int = 4
batch_size: int = 64 batch_size: int = 64
start_shard: int = 0 start_shard: int = 0
@@ -314,26 +307,25 @@ class DecoderTrainConfig(BaseModel):
 wd: SingularOrIterable[float] = 0.01
 warmup_steps: Optional[SingularOrIterable[int]] = None
 find_unused_parameters: bool = True
-static_graph: bool = True
 max_grad_norm: SingularOrIterable[float] = 0.5
 save_every_n_samples: int = 100000
 n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
 cond_scale: Union[float, List[float]] = 1.0
 device: str = 'cuda:0'
-epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
-validation_samples: Optional[int] = None # Same as above but for validation.
+epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+validation_samples: int = None # Same as above but for validation.
 save_immediately: bool = False
 use_ema: bool = True
 ema_beta: float = 0.999
 amp: bool = False
-unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets
+unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
 class DecoderEvaluateConfig(BaseModel):
 n_evaluation_samples: int = 1000
-FID: Optional[Dict[str, Any]] = None
-IS: Optional[Dict[str, Any]] = None
-KID: Optional[Dict[str, Any]] = None
-LPIPS: Optional[Dict[str, Any]] = None
+FID: Dict[str, Any] = None
+IS: Dict[str, Any] = None
+KID: Dict[str, Any] = None
+LPIPS: Dict[str, Any] = None
 class TrainDecoderConfig(BaseModel):
 decoder: DecoderConfig
@@ -347,14 +339,11 @@ class TrainDecoderConfig(BaseModel):
 def from_json_path(cls, json_path):
 with open(json_path) as f:
 config = json.load(f)
-print(config)
 return cls(**config)
-@model_validator(mode = 'after')
-def check_has_embeddings(self, m):
+@root_validator
+def check_has_embeddings(cls, values):
 # Makes sure that enough information is provided to get the embeddings specified for training
-values = dict(self)
 data_config, decoder_config = values.get('data'), values.get('decoder')
 if not exists(data_config) or not exists(decoder_config):
@@ -379,4 +368,4 @@ class TrainDecoderConfig(BaseModel):
 if text_emb_url:
 assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
-return m
+return values

View File

@@ -9,7 +9,7 @@ from collections.abc import Iterable
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
+from torch.optim.lr_scheduler import LambdaLR
 from torch.cuda.amp import autocast, GradScaler
 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -181,8 +181,7 @@ class DiffusionPriorTrainer(nn.Module):
 eps = 1e-6,
 max_grad_norm = None,
 group_wd_params = True,
-warmup_steps = None,
-cosine_decay_max_steps = None,
+warmup_steps = 1,
 **kwargs
 ):
 super().__init__()
@@ -234,11 +233,8 @@ class DiffusionPriorTrainer(nn.Module):
 **self.optim_kwargs,
 **kwargs
 )
-if exists(cosine_decay_max_steps):
-self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
-else:
-self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
 self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
@@ -275,7 +271,6 @@ class DiffusionPriorTrainer(nn.Module):
 # FIXME: LambdaLR can't be saved due to pickling issues
 save_obj = dict(
 optimizer = self.optimizer.state_dict(),
-scheduler = self.scheduler.state_dict(),
 warmup_scheduler = self.warmup_scheduler,
 model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
 version = version.parse(__version__),
@@ -322,9 +317,7 @@ class DiffusionPriorTrainer(nn.Module):
 # unwrap the model when loading from checkpoint
 self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
 self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
 self.optimizer.load_state_dict(loaded_obj['optimizer'])
-self.scheduler.load_state_dict(loaded_obj['scheduler'])
 # set warmupstep
 if exists(self.warmup_scheduler):
@@ -357,8 +350,7 @@ class DiffusionPriorTrainer(nn.Module):
 # accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
 if not self.accelerator.optimizer_step_was_skipped:
-sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
-with sched_context():
+with self.warmup_scheduler.dampening():
 self.scheduler.step()
 if self.use_ema:
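For reference, the scheduler composition main builds here (a linear warmup that dampens a cosine decay schedule) looks roughly like the following in isolation; the model, learning rate, and step counts are placeholders, and `pytorch-warmup` is assumed to be available (it is listed in setup.py in both versions):

```python
# Rough standalone sketch of the warmup + decay interplay used by DiffusionPriorTrainer on main.
import torch
import pytorch_warmup as warmup
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)

scheduler = CosineAnnealingLR(optimizer, T_max = 1000)                  # plays the role of cosine_decay_max_steps
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = 100)  # plays the role of warmup_steps

for _ in range(1000):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()

    # dampening() scales the learning rate during warmup, then hands control to the decay schedule
    with warmup_scheduler.dampening():
        scheduler.step()
```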
@@ -441,7 +433,6 @@ class DecoderTrainer(nn.Module):
 wd = 1e-2,
 eps = 1e-8,
 warmup_steps = None,
-cosine_decay_max_steps = None,
 max_grad_norm = 0.5,
 amp = False,
 group_wd_params = True,
@@ -463,7 +454,7 @@ class DecoderTrainer(nn.Module):
 # be able to finely customize learning rate, weight decay
 # per unet
-lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
+lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
 assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
@@ -471,7 +462,7 @@ class DecoderTrainer(nn.Module):
 schedulers = []
 warmup_schedulers = []
-for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
+for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
 if isinstance(unet, nn.Identity):
 optimizers.append(None)
 schedulers.append(None)
@@ -487,11 +478,7 @@ class DecoderTrainer(nn.Module):
 )
 optimizers.append(optimizer)
-if exists(unet_cosine_decay_max_steps):
-scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
-else:
-scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
 warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
 warmup_schedulers.append(warmup_scheduler)
@@ -571,15 +558,9 @@ class DecoderTrainer(nn.Module):
 for ind in range(0, self.num_unets):
 optimizer_key = f'optim{ind}'
-scheduler_key = f'sched{ind}'
 optimizer = getattr(self, optimizer_key)
-scheduler = getattr(self, scheduler_key)
-optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
-scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
-save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
+state_dict = optimizer.state_dict() if optimizer is not None else None
+save_obj = {**save_obj, optimizer_key: state_dict}
 if self.use_ema:
 save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -600,18 +581,10 @@ class DecoderTrainer(nn.Module):
 optimizer_key = f'optim{ind}'
 optimizer = getattr(self, optimizer_key)
-scheduler_key = f'sched{ind}'
-scheduler = getattr(self, scheduler_key)
 warmup_scheduler = self.warmup_schedulers[ind]
-if exists(optimizer):
+if optimizer is not None:
 optimizer.load_state_dict(loaded_obj[optimizer_key])
-if exists(scheduler):
-scheduler.load_state_dict(loaded_obj[scheduler_key])
 if exists(warmup_scheduler):
 warmup_scheduler.last_step = last_step

View File

@@ -1 +1 @@
-__version__ = '1.15.6'
+__version__ = '1.4.1'

View File

@@ -11,7 +11,8 @@ import torch.nn.functional as F
 from torch.autograd import grad as torch_grad
 import torchvision
-from einops import rearrange, reduce, repeat, pack, unpack
+from einops import rearrange, reduce, repeat
+from einops_exts import rearrange_many
 from einops.layers.torch import Rearrange
 # constants
@@ -407,7 +408,7 @@ class Attention(nn.Module):
 x = self.norm(x)
 q, k, v = self.to_qkv(x).chunk(3, dim = -1)
-q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
+q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
 q = q * self.scale
 sim = einsum('b h i d, b h j d -> b h i j', q, k)
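Both forms split the fused head dimension identically; a quick self-contained check of that equivalence (shapes are arbitrary examples), assuming both einops and einops-exts are available:

```python
# Quick check that the two rearrange forms above are equivalent.
import torch
from einops import rearrange
from einops_exts import rearrange_many

h = 4                                                   # number of attention heads
q, k, v = (torch.randn(2, 16, 32) for _ in range(3))    # (batch, seq, heads * dim_head)

out_map = [rearrange(t, 'b n (h d) -> b h n d', h = h) for t in (q, k, v)]
out_many = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

assert all(torch.equal(a, b) for a, b in zip(out_map, out_many))
```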

View File

@@ -26,17 +26,17 @@ setup(
 install_requires=[
 'accelerate',
 'click',
-'open-clip-torch>=2.0.0,<3.0.0',
-'clip-anytorch>=2.5.2',
+'clip-anytorch',
 'coca-pytorch>=0.0.5',
 'ema-pytorch>=0.0.7',
-'einops>=0.7.0',
+'einops>=0.4',
+'einops-exts>=0.0.3',
 'embedding-reader',
 'kornia>=0.5.4',
 'numpy',
 'packaging',
 'pillow',
-'pydantic>=2',
+'pydantic',
 'pytorch-warmup',
 'resize-right>=0.0.2',
 'rotary-embedding-torch',

View File

@@ -134,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
 break
 return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
-def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
+def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
 """
 Takes example data and generates images from the embeddings
 Returns three lists: real images, generated images, and captions
@@ -144,9 +144,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
 if img_embeddings[0] is None:
 # Generate image embeddings from clip
 imgs_tensor = torch.stack(real_images)
-assert clip is not None, "clip is None, but img_embeddings is None"
-imgs_tensor.to(device=device)
-img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
+img_embeddings, *_ = trainer.embed_image(imgs_tensor)
 sample_params["image_embed"] = img_embeddings
 else:
 # Then we are using precomputed image embeddings
@@ -155,10 +153,8 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
 if condition_on_text_encodings:
 if text_embeddings[0] is None:
 # Generate text embeddings from text
-assert clip is not None, "clip is None, but text_embeddings is None"
-tokenized_texts = tokenize(txts, truncate=True).to(device=device)
-text_embed, text_encodings = clip.embed_text(tokenized_texts)
-sample_params["text_encodings"] = text_encodings
+tokenized_texts = tokenize(txts, truncate=True)
+sample_params["text"] = tokenized_texts
 else:
 # Then we are using precomputed text embeddings
 text_embeddings = torch.stack(text_embeddings)
@@ -170,7 +166,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
sample_params["image"] = torch.stack(real_images) sample_params["image"] = torch.stack(real_images)
if device is not None: if device is not None:
sample_params["_device"] = device sample_params["_device"] = device
samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16 samples = trainer.sample(**sample_params)
generated_images = list(samples) generated_images = list(samples)
captions = [text_prepend + txt for txt in txts] captions = [text_prepend + txt for txt in txts]
if match_image_size: if match_image_size:
@@ -178,15 +174,15 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
 real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
 return real_images, generated_images, captions
-def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
+def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
 """
 Generates samples and uses torchvision to put them in a side by side grid for easy viewing
 """
-real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
+real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
 grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
 return grid_images, captions
-def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
 """
 Computes evaluation metrics for the decoder
 """
@@ -196,7 +192,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=Non
 if len(examples) == 0:
 print("No data to evaluate. Check that your dataloader has shards.")
 return metrics
-real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
+real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
 real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
 generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
 # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -229,8 +225,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=Non
metrics["KID_std"] = kid_std.item() metrics["KID_std"] = kid_std.item()
if exists(LPIPS): if exists(LPIPS):
# Convert from [0, 1] to [-1, 1] # Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1) renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1) renorm_generated_images = generated_images.mul(2).sub(1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync) lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
lpips.to(device=device) lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images) lpips.update(renorm_real_images, renorm_generated_images)
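The `.clamp(-1, 1)` added on main keeps the rescaled images inside the [-1, 1] range that torchmetrics' LPIPS expects by default, absorbing any floating-point overshoot. A tiny sketch of that rescaling in isolation:

```python
# Sketch of the [0, 1] -> [-1, 1] rescaling used before the LPIPS update above.
import torch

images = torch.rand(4, 3, 64, 64)              # placeholder batch of images in [0, 1]
renorm = images.mul(2).sub(1).clamp(-1, 1)     # now in [-1, 1]; clamp absorbs any overshoot

assert renorm.min() >= -1 and renorm.max() <= 1
```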
@@ -269,7 +265,6 @@ def train(
 accelerator: Accelerator,
 tracker: Tracker,
 inference_device,
-clip=None,
 evaluate_config=None,
 epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
 validation_samples = None,
@@ -376,19 +371,15 @@ def train(
 forward_params['image_embed'] = img_emb
 else:
 # Forward pass automatically generates embedding
-assert clip is not None
-img_embed, img_encoding = clip.embed_image(img)
-forward_params['image_embed'] = img_embed
+pass
 if condition_on_text_encodings:
 if has_text_embedding:
 forward_params['text_encodings'] = text_emb
 else:
 # Then we need to pass the text instead
-assert clip is not None
-tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
+tokenized_texts = tokenize(txt, truncate=True)
 assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
-text_embed, text_encodings = clip.embed_text(tokenized_texts)
-forward_params['text_encodings'] = text_encodings
+forward_params['text'] = tokenized_texts
 loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
 trainer.update(unet_number=unet)
 unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -428,7 +419,7 @@ def train(
 save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
 if exists(n_sample_images) and n_sample_images > 0:
 trainer.eval()
-train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
 tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
 if epoch_samples is not None and sample >= epoch_samples:
@@ -471,19 +462,15 @@ def train(
 forward_params['image_embed'] = img_emb.float()
 else:
 # Forward pass automatically generates embedding
-assert clip is not None
-img_embed, img_encoding = clip.embed_image(img)
-forward_params['image_embed'] = img_embed
+pass
 if condition_on_text_encodings:
 if has_text_embedding:
 forward_params['text_encodings'] = text_emb.float()
 else:
 # Then we need to pass the text instead
-assert clip is not None
-tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
+tokenized_texts = tokenize(txt, truncate=True)
 assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
-text_embed, text_encodings = clip.embed_text(tokenized_texts)
-forward_params['text_encodings'] = text_encodings
+forward_params['text'] = tokenized_texts
 loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
 average_val_loss_tensor[0, unet-1] += loss
@@ -511,7 +498,7 @@ def train(
 if next_task == 'eval':
 if exists(evaluate_config):
 accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.model_dump(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
+evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
 if is_master:
 tracker.log(evaluation, step=step())
 next_task = 'sample'
@@ -522,8 +509,8 @@ def train(
 # Generate examples and save the model if we are the master
 # Generate sample images
 print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
-test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
-train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
+train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
 tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
 tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
@@ -545,10 +532,9 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
"NumProcesses": accelerator.num_processes, "NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision "MixedPrecision": accelerator.mixed_precision
} }
accelerator.wait_for_everyone() # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy) tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json') tracker.save_config(config_path, config_name='decoder_config.json')
tracker.add_save_metadata(state_dict_key='config', metadata=config.model_dump()) tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
return tracker return tracker
def initialize_training(config: TrainDecoderConfig, config_path): def initialize_training(config: TrainDecoderConfig, config_path):
@@ -556,7 +542,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
 torch.manual_seed(config.seed)
 # Set up accelerator for configurable distributed training
-ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
+ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
 init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
 accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
@@ -569,6 +555,10 @@ def initialize_training(config: TrainDecoderConfig, config_path):
 # If we are in deepspeed fp16 mode, we must ensure learned variance is off
 if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
 raise ValueError("DeepSpeed fp16 mode does not support learned variance")
+if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
+# This is an invalid configuration until we figure out how to handle this
+raise ValueError("DeepSpeed does not support multi-node distributed training")
 # Set up data
 all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
@@ -577,7 +567,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
 shards_per_process = len(all_shards) // world_size
 assert shards_per_process > 0, "Not enough shards to split evenly"
 my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
 dataloaders = create_dataloaders (
 available_shards=my_shards,
 img_preproc = config.data.img_preproc,
@@ -585,16 +574,11 @@ def initialize_training(config: TrainDecoderConfig, config_path):
 val_prop = config.data.splits.val,
 test_prop = config.data.splits.test,
 n_sample_images=config.train.n_sample_images,
-**config.data.model_dump(),
+**config.data.dict(),
 rank = rank,
 seed = config.seed,
 )
-# If clip is in the model, we need to remove it for compatibility with deepspeed
-clip = None
-if config.decoder.clip is not None:
-clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues
-config.decoder.clip = None
 # Create the decoder model and print basic info
 decoder = config.decoder.create()
 get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
@@ -606,7 +590,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
 has_text_embeddings = config.data.text_embeddings_url is not None
 conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
-has_clip_model = clip is not None
+has_clip_model = config.decoder.clip is not None
 data_source_string = ""
 if has_img_embeddings:
@@ -631,12 +615,11 @@ def initialize_training(config: TrainDecoderConfig, config_path):
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training") accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
train(dataloaders, decoder, accelerator, train(dataloaders, decoder, accelerator,
clip=clip,
tracker=tracker, tracker=tracker,
inference_device=accelerator.device, inference_device=accelerator.device,
evaluate_config=config.evaluate, evaluate_config=config.evaluate,
condition_on_text_encodings=conditioning_on_text, condition_on_text_encodings=conditioning_on_text,
**config.train.model_dump(), **config.train.dict(),
) )
# Create a simple click command line interface to load the config and start the training # Create a simple click command line interface to load the config and start the training