mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 17:14:38 +01:00

Compare commits — 1 commit: 6ea65f59cc

In the hunks below, the pydantic-based configuration code appears on the `-` side and the dict-based configuration code it is compared against on the `+` side.
README.md

@@ -1076,7 +1076,6 @@ This library would not have gotten to this working state without the help of
 - [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
 - [x] cross embed layers for downsampling, as an option
 - [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
-- [x] use pydantic for config driven training
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
@@ -4,7 +4,7 @@ For more complex configuration, we provide the option of using a configuration f

 ### Decoder Trainer

-The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
+The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.json.example).

 **<ins>Unets</ins>:**
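The seven sections referenced in the changed line above are `unets`, `decoder`, `data`, `train`, `evaluate`, `tracker`, and `load` (they match the keys of `configs/decoder_defaults.py`, added below). A skeletal sketch of that shape — the values here are illustrative placeholders, not the shipped example file:

```json
{
    "unets": [{"dim": 16, "dim_mults": [1, 2, 3, 4]}],
    "decoder": {"image_sizes": [64], "image_size": 64},
    "data": {"webdataset_base_url": "...", "embeddings_url": "..."},
    "train": {"epochs": 20, "lr": 1e-4},
    "evaluate": {"FID": {"feature": 64}},
    "tracker": {"tracker_type": "console"},
    "load": {"source": null}
}
```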
configs/decoder_defaults.py — new file, 82 lines

@@ -0,0 +1,82 @@
+"""
+Defines the default values for the decoder config
+"""
+
+from enum import Enum
+
+class ConfigField(Enum):
+    REQUIRED = 0  # This had more options. It's a bit unnecessary now, but I can't think of a better way to do it.
+
+default_config = {
+    "unets": ConfigField.REQUIRED,
+    "decoder": {
+        "image_sizes": ConfigField.REQUIRED,  # The side lengths of the upsampled image at the end of each unet
+        "image_size": ConfigField.REQUIRED,   # Usually the same as image_sizes[-1] I think
+        "channels": 3,
+        "timesteps": 1000,
+        "loss_type": "l2",
+        "beta_schedule": "cosine",
+        "learned_variance": True
+    },
+    "data": {
+        "webdataset_base_url": ConfigField.REQUIRED,  # Path to a webdataset with jpg images
+        "embeddings_url": ConfigField.REQUIRED,       # Path to .npy files with embeddings
+        "num_workers": 4,
+        "batch_size": 64,
+        "start_shard": 0,
+        "end_shard": 9999999,
+        "shard_width": 6,
+        "index_width": 4,
+        "splits": {
+            "train": 0.75,
+            "val": 0.15,
+            "test": 0.1
+        },
+        "shuffle_train": True,
+        "resample_train": False,
+        "preprocessing": {
+            "ToTensor": True
+        }
+    },
+    "train": {
+        "epochs": 20,
+        "lr": 1e-4,
+        "wd": 0.01,
+        "max_grad_norm": 0.5,
+        "save_every_n_samples": 100000,
+        "n_sample_images": 6,       # The number of example images to produce when sampling the train and test dataset
+        "device": "cuda:0",
+        "epoch_samples": None,      # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+        "validation_samples": None, # Same as above but for validation.
+        "use_ema": True,
+        "ema_beta": 0.99,
+        "amp": False,
+        "save_all": False,          # Whether to preserve all checkpoints
+        "save_latest": True,        # Whether to always save the latest checkpoint
+        "save_best": True,          # Whether to save the best checkpoint
+        "unet_training_mask": None  # If None, use all unets
+    },
+    "evaluate": {
+        "n_evalation_samples": 1000,
+        "FID": None,
+        "IS": None,
+        "KID": None,
+        "LPIPS": None
+    },
+    "tracker": {
+        "tracker_type": "console",  # Decoder currently supports console and wandb
+        "data_path": "./models",    # The path where files will be saved locally
+        "wandb_entity": "",         # Only needs to be set if tracker_type is wandb
+        "wandb_project": "",
+        "verbose": False            # Whether to print console logging for non-console trackers
+    },
+    "load": {
+        "source": None,             # Supports file and wandb
+        "run_path": "",             # Used only if source is wandb
+        "file_path": "",            # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
+        "resume": False             # If using wandb, whether to resume the run
+    }
+}
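A minimal usage sketch of the dict-based config this file backs — assumed usage based on the `map_config` code in this diff; the bucket URLs are placeholders:

```python
from dalle2_pytorch.train_configs import TrainDecoderConfig

config = TrainDecoderConfig({
    "unets": [{"dim": 16, "dim_mults": [1, 2, 3, 4]}],   # REQUIRED
    "decoder": {"image_sizes": [64], "image_size": 64},  # image_sizes / image_size are REQUIRED
    "data": {
        "webdataset_base_url": "s3://my-bucket/shards/",  # REQUIRED (placeholder)
        "embeddings_url": "s3://my-bucket/embeddings/",   # REQUIRED (placeholder)
    },
    # "train", "evaluate", "tracker" and "load" may be omitted entirely;
    # map_config fills them in from default_config
})

print(config["train"]["lr"])  # 0.0001 — filled in from the defaults above
# Omitting "data" instead would raise:
# RuntimeError: Required config value 'webdataset_base_url' of option 'data' not supplied
```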
@@ -12,6 +12,7 @@
     ],
     "decoder": {
         "image_sizes": [64],
+        "image_size": [64],
         "channels": 3,
         "timesteps": 1000,
         "loss_type": "l2",

@@ -62,7 +63,7 @@
         "unet_training_mask": [true]
     },
     "evaluate": {
-        "n_evaluation_samples": 1000,
+        "n_evalation_samples": 1000,
         "FID": {
             "feature": 64
         },
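The metric entries under `evaluate` read as constructor kwargs for the corresponding torchmetrics classes (train_decoder.py already imports `FrechetInceptionDistance`), so `"FID": {"feature": 64}` plausibly maps to something like the following — a sketch of the assumed wiring, not code from this diff:

```python
from torchmetrics.image.fid import FrechetInceptionDistance

# feature=64 selects the 64-dimensional InceptionV3 feature layer,
# cheaper than the default 2048-dimensional features
fid = FrechetInceptionDistance(feature=64)
```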
dalle2_pytorch/train_configs.py

@@ -1,125 +1,47 @@
-import json
 from torchvision import transforms as T
-from pydantic import BaseModel, validator, root_validator
-from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
+from configs.decoder_defaults import default_config, ConfigField

-def exists(val):
-    return val is not None
-
-def default(val, d):
-    return val if exists(val) else d
-
-class UnetConfig(BaseModel):
-    dim: int
-    dim_mults: List[int]
-    image_embed_dim: int = None
-    cond_dim: int = None
-    channels: int = 3
-    attn_dim_head: int = 32
-    attn_heads: int = 16
-
-    class Config:
-        extra = "allow"
-
-class DecoderConfig(BaseModel):
-    image_size: int = None
-    image_sizes: Union[List[int], Tuple[int]] = None
-    channels: int = 3
-    timesteps: int = 1000
-    loss_type: str = 'l2'
-    beta_schedule: str = 'cosine'
-    learned_variance: bool = True
-
-    @validator('image_sizes')
-    def check_image_sizes(cls, image_sizes, values):
-        if exists(values.get('image_size')) ^ exists(image_sizes):
-            return image_sizes
-        raise ValueError('either image_size or image_sizes is required, but not both')
-
-    class Config:
-        extra = "allow"
-
-class TrainSplitConfig(BaseModel):
-    train: float = 0.75
-    val: float = 0.15
-    test: float = 0.1
-
-    @root_validator
-    def validate_all(cls, fields):
-        if sum([*fields.values()]) != 1.:
-            raise ValueError(f'{fields.keys()} must sum to 1.0')
-        return fields
-
-class DecoderDataConfig(BaseModel):
-    webdataset_base_url: str  # path to a webdataset with jpg images
-    embeddings_url: str       # path to .npy files with embeddings
-    num_workers: int = 4
-    batch_size: int = 64
-    start_shard: int = 0
-    end_shard: int = 9999999
-    shard_width: int = 6
-    index_width: int = 4
-    splits: TrainSplitConfig
-    shuffle_train: bool = True
-    resample_train: bool = False
-    preprocessing: Dict[str, Any] = {'ToTensor': True}
-
-class DecoderTrainConfig(BaseModel):
-    epochs: int = 20
-    lr: float = 1e-4
-    wd: float = 0.01
-    max_grad_norm: float = 0.5
-    save_every_n_samples: int = 100000
-    n_sample_images: int = 6        # The number of example images to produce when sampling the train and test dataset
-    device: str = 'cuda:0'
-    epoch_samples: int = None       # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
-    validation_samples: int = None  # Same as above but for validation.
-    use_ema: bool = True
-    ema_beta: float = 0.99
-    amp: bool = False
-    save_all: bool = False          # Whether to preserve all checkpoints
-    save_latest: bool = True        # Whether to always save the latest checkpoint
-    save_best: bool = True          # Whether to save the best checkpoint
-    unet_training_mask: List[bool] = None  # If None, use all unets
-
-class DecoderEvaluateConfig(BaseModel):
-    n_evaluation_samples: int = 1000
-    FID: Dict[str, Any] = None
-    IS: Dict[str, Any] = None
-    KID: Dict[str, Any] = None
-    LPIPS: Dict[str, Any] = None
-
-class TrackerConfig(BaseModel):
-    tracker_type: str = 'console'  # Decoder currently supports console and wandb
-    data_path: str = './models'    # The path where files will be saved locally
-    init_config: Dict[str, Any] = None
-    wandb_entity: str = ''         # Only needs to be set if tracker_type is wandb
-    wandb_project: str = ''
-    verbose: bool = False          # Whether to print console logging for non-console trackers
-
-class DecoderLoadConfig(BaseModel):
-    source: str = None   # Supports file and wandb
-    run_path: str = ''   # Used only if source is wandb
-    file_path: str = ''  # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
-    resume: bool = False # If using wandb, whether to resume the run
-
-class TrainDecoderConfig(BaseModel):
-    unets: List[UnetConfig]
-    decoder: DecoderConfig
-    data: DecoderDataConfig
-    train: DecoderTrainConfig
-    evaluate: DecoderEvaluateConfig
-    tracker: TrackerConfig
-    load: DecoderLoadConfig
-
-    @classmethod
-    def from_json_path(cls, json_path):
-        with open(json_path) as f:
-            config = json.load(f)
-        return cls(**config)
-
-    @property
-    def img_preproc(self):
+class TrainDecoderConfig:
+    def __init__(self, config):
+        self.config = self.map_config(config, default_config)
+
+    def map_config(self, config, defaults):
+        """
+        Returns a dictionary containing all config options in the union of config and defaults.
+        If the config value is an array, apply the default value to each element.
+        If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
+        """
+        def _check_option(option, option_config, option_defaults):
+            for key, value in option_defaults.items():
+                if key not in option_config:
+                    if value == ConfigField.REQUIRED:
+                        raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
+                    option_config[key] = value
+
+        for key, value in defaults.items():
+            if key not in config:
+                # Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If it is a required object, we must error
+                if value == ConfigField.REQUIRED:
+                    raise RuntimeError("Required config value '{}' not supplied".format(key))
+                elif isinstance(value, dict):
+                    config[key] = {}
+                elif isinstance(value, list):
+                    config[key] = [{}]
+            # config[key] is now either a dict, a list of dicts, or an object that cannot be checked.
+            # If it is a list, then we need to check each element
+            if isinstance(value, list):
+                assert isinstance(config[key], list)
+                for element in config[key]:
+                    _check_option(key, element, value[0])
+            elif isinstance(value, dict):
+                _check_option(key, config[key], value)
+            # This object does not support checking
+        return config
+
+    def get_preprocessing(self):
+        """
+        Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
+        """
         def _get_transformation(transformation_name, **kwargs):
             if transformation_name == "RandomResizedCrop":
                 return T.RandomResizedCrop(**kwargs)

@@ -128,8 +50,13 @@ class TrainDecoderConfig(BaseModel):
             elif transformation_name == "ToTensor":
                 return T.ToTensor()

-        transforms = []
-        for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
-            transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
-            transforms.append(_get_transformation(transform_name, **transform_kwargs))
-        return T.Compose(transforms)
+        transformations = []
+        for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
+            if isinstance(transformation_kwargs, dict):
+                transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
+            else:
+                transformations.append(_get_transformation(transformation_name))
+        return T.Compose(transformations)
+
+    def __getitem__(self, key):
+        return self.config[key]
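For contrast with the dict access shown elsewhere in this diff, a minimal sketch of how the pydantic side (the `-` lines above) is used — assumed usage, under pydantic v1 semantics since the code relies on `validator`/`root_validator`; the JSON path is a placeholder:

```python
from dalle2_pytorch.train_configs import TrainDecoderConfig, TrainSplitConfig

# Validates the whole nested config on load, then gives typed attribute access
config = TrainDecoderConfig.from_json_path("train_decoder_config.json")
print(config.train.lr, config.data.splits.train)

# Sub-configs fail fast on inconsistent input; this raises a ValidationError
# because the root_validator requires the splits to sum to 1.0:
TrainSplitConfig(train=0.9, val=0.2, test=0.1)
```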
setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.2',
+  version = '0.3.9',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

@@ -32,7 +32,6 @@ setup(
     'kornia>=0.5.4',
     'numpy',
     'pillow',
-    'pydantic',
     'resize-right>=0.0.2',
     'rotary-embedding-torch',
     'torch>=1.10',
train_decoder.py

@@ -5,6 +5,7 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
 from dalle2_pytorch.train_configs import TrainDecoderConfig
 from dalle2_pytorch.utils import Timer
+import json
 import torchvision
 import torch
 from torchmetrics.image.fid import FrechetInceptionDistance

@@ -89,11 +90,11 @@ def create_dataloaders(
 def create_decoder(device, decoder_config, unets_config):
     """Creates a sample decoder"""
-    unets = [Unet(**config.dict()) for config in unets_config]
+    unets = [Unet(**config) for config in unets_config]

     decoder = Decoder(
         unet=unets,
-        **decoder_config.dict()
+        **decoder_config
     )

     decoder.to(device=device)

@@ -153,13 +154,13 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions

-def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
     metrics = {}
     # Prepare the data
-    examples = get_example_data(dataloader, device, n_evaluation_samples)
+    examples = get_example_data(dataloader, device, n_evalation_samples)
     real_images, generated_images, captions = generate_samples(trainer, examples)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)

@@ -251,8 +252,8 @@ def train(
     start_epoch = 0
     validation_losses = []

-    if exists(load_config) and exists(load_config.source):
-        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
+    if exists(load_config) and exists(load_config["source"]):
+        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
     trainer.to(device=inference_device)

     if not exists(unet_training_mask):

@@ -270,6 +271,7 @@ def train(
     for epoch in range(start_epoch, epochs):
         print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
+        trainer.train()

         timer = Timer()

@@ -278,13 +280,11 @@ def train(
         last_snapshot = 0
         losses = []

         for i, (img, emb) in enumerate(dataloaders["train"]):
             step += 1
             sample += img.shape[0]
             img, emb = send_to_device((img, emb))

-            trainer.train()
             for unet in range(1, trainer.num_unets+1):
                 # Check if this is a unet we are training
                 if not unet_training_mask[unet-1]: # Unet index is the unet number - 1

@@ -299,7 +299,7 @@ def train(
                 timer.reset()
                 last_sample = sample

-            if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
+            if i % CALC_LOSS_EVERY_ITERS == 0:
                 average_loss = sum(losses) / len(losses)
                 log_data = {
                     "Training loss": average_loss,

@@ -319,12 +319,11 @@ def train(
                 save_paths.append("latest.pth")
                 if save_all:
                     save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
                 save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
                 if exists(n_sample_images) and n_sample_images > 0:
                     trainer.eval()
                     train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
+                    trainer.train()
                     tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)

         if exists(epoch_samples) and sample >= epoch_samples:

@@ -359,6 +358,7 @@ def train(
         tracker.log(log_data, step=step, verbose=True)

         # Compute evaluation metrics
+        trainer.eval()
         if exists(evaluate_config):
             print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
             evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)

@@ -385,25 +385,21 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
     """
     Creates a tracker of the specified type and initializes special features based on the full config
     """
-    tracker_config = config.tracker
+    tracker_config = config["tracker"]
     init_config = {}
-    if exists(tracker_config.init_config):
-        init_config["config"] = tracker_config.init_config
+    init_config["config"] = config.config

     if tracker_type == "console":
         tracker = ConsoleTracker(**init_config)
     elif tracker_type == "wandb":
         # We need to initialize the resume state here
-        load_config = config.load
-        if load_config.source == "wandb" and load_config.resume:
+        load_config = config["load"]
+        if load_config["source"] == "wandb" and load_config["resume"]:
             # Then we are resuming the run load_config["run_path"]
-            run_id = load_config.run_path.split("/")[-1]
+            run_id = config["resume"]["wandb_run_path"].split("/")[-1]
             init_config["id"] = run_id
             init_config["resume"] = "must"
-        init_config["entity"] = tracker_config.wandb_entity
-        init_config["project"] = tracker_config.wandb_project
+        init_config["entity"] = tracker_config["wandb_entity"]
+        init_config["project"] = tracker_config["wandb_project"]
         tracker = WandbTracker(data_path)
         tracker.init(**init_config)
     else:

@@ -412,35 +408,35 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
 def initialize_training(config):
     # Create the save path
-    if "cuda" in config.train.device:
+    if "cuda" in config["train"]["device"]:
         assert torch.cuda.is_available(), "CUDA is not available"
-    device = torch.device(config.train.device)
+    device = torch.device(config["train"]["device"])
     torch.cuda.set_device(device)
-    all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
+    all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))

     dataloaders = create_dataloaders (
         available_shards=all_shards,
-        img_preproc = config.img_preproc,
-        train_prop = config.data.splits.train,
-        val_prop = config.data.splits.val,
-        test_prop = config.data.splits.test,
-        n_sample_images=config.train.n_sample_images,
-        **config.data.dict()
+        img_preproc = config.get_preprocessing(),
+        train_prop = config["data"]["splits"]["train"],
+        val_prop = config["data"]["splits"]["val"],
+        test_prop = config["data"]["splits"]["test"],
+        n_sample_images=config["train"]["n_sample_images"],
+        **config["data"]
     )

-    decoder = create_decoder(device, config.decoder, config.unets)
+    decoder = create_decoder(device, config["decoder"], config["unets"])
     num_parameters = sum(p.numel() for p in decoder.parameters())
     print(print_ribbon("Loaded Config", repeat=40))
     print(f"Number of parameters: {num_parameters}")

-    tracker = create_tracker(config, **config.tracker.dict())
+    tracker = create_tracker(config, **config["tracker"])

     train(dataloaders, decoder,
         tracker=tracker,
         inference_device=device,
-        load_config=config.load,
-        evaluate_config=config.evaluate,
-        **config.train.dict(),
+        load_config=config["load"],
+        evaluate_config=config["evaluate"],
+        **config["train"],
     )

 # Create a simple click command line interface to load the config and start the training

@@ -448,7 +444,9 @@ def initialize_training(config):
 @click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
 def main(config_file):
     print("Recalling config from {}".format(config_file))
-    config = TrainDecoderConfig.from_json_path(config_file)
+    with open(config_file) as f:
+        config = json.load(f)
+    config = TrainDecoderConfig(config)
     initialize_training(config)
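Assuming the training script is invoked directly, usage is the same on both sides of this diff: `python train_decoder.py --config_file ./train_decoder_config.json`, with click supplying the default `./train_decoder_config.json` when the flag is omitted.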