From 58892135d9bcf117921c885dda161c0b67452096 Mon Sep 17 00:00:00 2001 From: Aidan Dempster Date: Sun, 19 Jun 2022 12:25:54 -0400 Subject: [PATCH] Distributed Training of the Decoder (#121) * Converted decoder trainer to use accelerate * Fixed issue where metric evaluation would hang in distributed mode * Implemented functional saving; loading still fails due to some issue with the optimizer * Fixed issue with loading decoders * Fixed issue with tracker config * Fixed issue with amp; updated logging to be more logical * Saving checkpoint now saves position in training as well; fixed an issue with running out of GPU space due to loading weights into the GPU twice * Fixed ema for distributed training * Fixed issue where get_pkg_version was reintroduced * Changed decoder trainer to upload config as a file; fixed issue where loading best would error --- dalle2_pytorch/dalle2_pytorch.py | 5 +- dalle2_pytorch/dataloaders/decoder_loader.py | 3 - dalle2_pytorch/trackers.py | 57 ++- dalle2_pytorch/train_configs.py | 1 + dalle2_pytorch/trainer.py | 76 ++-- dalle2_pytorch/utils.py | 1 + train_decoder.py | 395 ++++++++++++------- 7 files changed, 331 insertions(+), 207 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 3c6afb5..26ba70a 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -2099,7 +2099,8 @@ class Decoder(BaseGaussianDiffusion): text_encodings = None, batch_size = 1, cond_scale = 1., - stop_at_unet_number = None + stop_at_unet_number = None, + distributed = False, ): assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally' @@ -2118,7 +2119,7 @@ class Decoder(BaseGaussianDiffusion): for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)): - context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context() + context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context() with context: lowres_cond_img = None diff --git a/dalle2_pytorch/dataloaders/decoder_loader.py b/dalle2_pytorch/dataloaders/decoder_loader.py index 24a642b..5681e2a 100644 --- a/dalle2_pytorch/dataloaders/decoder_loader.py +++ b/dalle2_pytorch/dataloaders/decoder_loader.py @@ -164,9 +164,6 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface): # There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler)) - self.append(wds.split_by_node) - self.append(wds.split_by_worker) - self.append(wds.tarfile_to_samples(handler=handler)) self.append(wds.decode("pilrgb", handler=handler)) if embedding_folder_url is not None: diff --git a/dalle2_pytorch/trackers.py b/dalle2_pytorch/trackers.py index 9204f2e..6569f2a 100644 --- a/dalle2_pytorch/trackers.py +++ b/dalle2_pytorch/trackers.py @@ -17,15 +17,15 @@ DEFAULT_DATA_PATH = './.tracker-data' def exists(val): return val is not None -# load state dict functions +# load file functions -def load_wandb_state_dict(run_path, file_path, **kwargs): +def load_wandb_file(run_path, file_path, **kwargs): wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function') file_reference = wandb.restore(file_path, run_path=run_path) - return torch.load(file_reference.name) + return file_reference.name -def load_local_state_dict(file_path, **kwargs): - return torch.load(file_path) +def load_local_file(file_path, **kwargs): + return file_path # base class @@ -55,12 +55,43 @@ class BaseTracker(nn.Module): """ # TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement if recall_source == 'wandb': - return load_wandb_state_dict(*args, **kwargs) + return torch.load(load_wandb_file(*args, **kwargs)) elif recall_source == 'local': - return load_local_state_dict(*args, **kwargs) + return torch.load(load_local_file(*args, **kwargs)) else: raise ValueError('`recall_source` must be one of `wandb` or `local`') + def save_file(self, file_path, **kwargs): + raise NotImplementedError + + def recall_file(self, recall_source, *args, **kwargs): + if recall_source == 'wandb': + return load_wandb_file(*args, **kwargs) + elif recall_source == 'local': + return load_local_file(*args, **kwargs) + else: + raise ValueError('`recall_source` must be one of `wandb` or `local`') + +# Tracker that no-ops all calls except for recall + +class DummyTracker(BaseTracker): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def init(self, config, **kwargs): + pass + + def log(self, log, **kwargs): + pass + + def log_images(self, images, **kwargs): + pass + + def save_state_dict(self, state_dict, relative_path, **kwargs): + pass + + def save_file(self, file_path, **kwargs): + pass # basic stdout class @@ -76,6 +107,10 @@ class ConsoleTracker(BaseTracker): def save_state_dict(self, state_dict, relative_path, **kwargs): torch.save(state_dict, str(self.data_path / relative_path)) + + def save_file(self, file_path, **kwargs): + # This is a no-op for local file systems since it is already saved locally + pass # basic wandb class @@ -107,3 +142,11 @@ class WandbTracker(BaseTracker): full_path = str(self.data_path / relative_path) torch.save(state_dict, full_path) self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path + + def save_file(self, file_path, base_path=None, **kwargs): + """ + Uploads a file from disk to wandb + """ + if base_path is None: + base_path = self.data_path + self.wandb.save(str(file_path), base_path = str(base_path)) diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py index 28b8f89..56713d6 100644 --- a/dalle2_pytorch/train_configs.py +++ b/dalle2_pytorch/train_configs.py @@ -261,6 +261,7 @@ class TrainDecoderConfig(BaseModel): evaluate: DecoderEvaluateConfig tracker: TrackerConfig load: DecoderLoadConfig + 
seed: int = 0 @classmethod def from_json_path(cls, json_path): diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index 6cc609b..7008bf7 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -574,6 +574,7 @@ def decoder_sample_in_chunks(fn): class DecoderTrainer(nn.Module): def __init__( self, + accelerator, decoder, use_ema = True, lr = 1e-4, @@ -588,8 +589,9 @@ class DecoderTrainer(nn.Module): assert isinstance(decoder, Decoder) ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) - self.decoder = decoder - self.num_unets = len(self.decoder.unets) + self.accelerator = accelerator + + self.num_unets = len(decoder.unets) self.use_ema = use_ema self.ema_unets = nn.ModuleList([]) @@ -601,7 +603,9 @@ class DecoderTrainer(nn.Module): lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps)) - for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)): + optimizers = [] + + for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps): optimizer = get_optimizer( unet.parameters(), lr = unet_lr, @@ -611,19 +615,20 @@ class DecoderTrainer(nn.Module): **kwargs ) - setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers + optimizers.append(optimizer) if self.use_ema: self.ema_unets.append(EMA(unet, **ema_kwargs)) - scaler = GradScaler(enabled = amp) - setattr(self, f'scaler{ind}', scaler) - # gradient clipping if needed self.max_grad_norm = max_grad_norm self.register_buffer('step', torch.tensor([0.])) + results = list(self.accelerator.prepare(decoder, *optimizers)) + self.decoder = results.pop(0) + for opt_ind in range(len(optimizers)): + setattr(self, f'optim{opt_ind}', results.pop(0)) def save(self, path, overwrite = True, **kwargs): path = Path(path) @@ -631,47 +636,42 @@ class DecoderTrainer(nn.Module): path.parent.mkdir(parents = True, exist_ok = True) save_obj = dict( - model = self.decoder.state_dict(), + model = self.accelerator.unwrap_model(self.decoder).state_dict(), version = __version__, step = self.step.item(), **kwargs ) for ind in range(0, self.num_unets): - scaler_key = f'scaler{ind}' - optimizer_key = f'scaler{ind}' - scaler = getattr(self, scaler_key) + optimizer_key = f'optim{ind}' optimizer = getattr(self, optimizer_key) - save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()} + save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()} if self.use_ema: save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()} - torch.save(save_obj, str(path)) + self.accelerator.save(save_obj, str(path)) def load(self, path, only_model = False, strict = True): path = Path(path) assert path.exists() - loaded_obj = torch.load(str(path)) + loaded_obj = torch.load(str(path), map_location = 'cpu') - if version.parse(__version__) != loaded_obj['version']: - print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}') + if version.parse(__version__) != version.parse(loaded_obj['version']): + self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}') - self.decoder.load_state_dict(loaded_obj['model'], strict = strict) + self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict) self.step.copy_(torch.ones_like(self.step) * loaded_obj['step']) if only_model: return loaded_obj for ind in range(0, self.num_unets): 
- scaler_key = f'scaler{ind}' - optimizer_key = f'scaler{ind}' - scaler = getattr(self, scaler_key) + optimizer_key = f'optim{ind}' optimizer = getattr(self, optimizer_key) - scaler.load_state_dict(loaded_obj[scaler_key]) - optimizer.load_state_dict(loaded_obj[optimizer_key]) + self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key]) if self.use_ema: assert 'ema' in loaded_obj @@ -683,29 +683,18 @@ class DecoderTrainer(nn.Module): def unets(self): return nn.ModuleList([ema.ema_model for ema in self.ema_unets]) - def scale(self, loss, *, unet_number): - assert 1 <= unet_number <= self.num_unets - index = unet_number - 1 - scaler = getattr(self, f'scaler{index}') - return scaler.scale(loss) - def update(self, unet_number = None): if self.num_unets == 1: unet_number = default(unet_number, 1) assert exists(unet_number) and 1 <= unet_number <= self.num_unets index = unet_number - 1 - unet = self.decoder.unets[index] optimizer = getattr(self, f'optim{index}') - scaler = getattr(self, f'scaler{index}') if exists(self.max_grad_norm): - scaler.unscale_(optimizer) - nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm) - - scaler.step(optimizer) - scaler.update() + self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients + optimizer.step() optimizer.zero_grad() if self.use_ema: @@ -718,15 +707,17 @@ class DecoderTrainer(nn.Module): @cast_torch_tensor @decoder_sample_in_chunks def sample(self, *args, **kwargs): + distributed = self.accelerator.num_processes > 1 + base_decoder = self.accelerator.unwrap_model(self.decoder) if kwargs.pop('use_non_ema', False) or not self.use_ema: - return self.decoder.sample(*args, **kwargs) + return base_decoder.sample(*args, **kwargs, distributed = distributed) - trainable_unets = self.decoder.unets - self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling + trainable_unets = self.accelerator.unwrap_model(self.decoder).unets + base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling - output = self.decoder.sample(*args, **kwargs) + output = base_decoder.sample(*args, **kwargs, distributed = distributed) - self.decoder.unets = trainable_unets # restore original training unets + base_decoder.unets = trainable_unets # restore original training unets # cast the ema_model unets back to original device for ema in self.ema_unets: @@ -748,13 +739,14 @@ class DecoderTrainer(nn.Module): total_loss = 0. 
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): - with autocast(enabled = self.amp): + # with autocast(enabled = self.amp): + with self.accelerator.autocast(): loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs) loss = loss * chunk_size_frac total_loss += loss.item() if self.training: - self.scale(loss, unet_number = unet_number).backward() + self.accelerator.backward(loss) return total_loss diff --git a/dalle2_pytorch/utils.py b/dalle2_pytorch/utils.py index 7208f3e..45e5ee5 100644 --- a/dalle2_pytorch/utils.py +++ b/dalle2_pytorch/utils.py @@ -1,4 +1,5 @@ import time +import importlib # time helpers diff --git a/train_decoder.py b/train_decoder.py index c6ed801..53953b3 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -1,7 +1,8 @@ -from dalle2_pytorch import Unet, Decoder +from pathlib import Path + from dalle2_pytorch.trainer import DecoderTrainer from dalle2_pytorch.dataloaders import create_image_embedding_dataloader -from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker +from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker from dalle2_pytorch.train_configs import TrainDecoderConfig from dalle2_pytorch.utils import Timer, print_ribbon from dalle2_pytorch.dalle2_pytorch import resize_image_to @@ -12,6 +13,8 @@ from torchmetrics.image.fid import FrechetInceptionDistance from torchmetrics.image.inception import InceptionScore from torchmetrics.image.kid import KernelInceptionDistance from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from accelerate import Accelerator, DistributedDataParallelKwargs +from accelerate.utils import dataclasses as accelerate_dataclasses import webdataset as wds import click @@ -42,6 +45,7 @@ def create_dataloaders( train_prop = 0.75, val_prop = 0.15, test_prop = 0.10, + seed = 0, **kwargs ): """ @@ -52,7 +56,7 @@ def create_dataloaders( num_test = round(test_prop*len(available_shards)) num_val = len(available_shards) - num_train - num_test assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}" - train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0)) + train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(seed)) # The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename. train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split] @@ -117,7 +121,6 @@ def get_example_data(dataloader, device, n=5): captions.extend(list(txt)) if len(images) >= n: break - print("Generated {} examples".format(len(images))) return list(zip(images[:n], embeddings[:n], captions[:n])) def generate_samples(trainer, example_data, text_prepend=""): @@ -155,27 +158,34 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID metrics = {} # Prepare the data examples = get_example_data(dataloader, device, n_evaluation_samples) + if len(examples) == 0: + print("No data to evaluate. 
Check that your dataloader has shards.") + return metrics real_images, generated_images, captions = generate_samples(trainer, examples) real_images = torch.stack(real_images).to(device=device, dtype=torch.float) generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float) # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8 int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8) int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8) + + def null_sync(t, *args, **kwargs): + return [t] + if exists(FID): - fid = FrechetInceptionDistance(**FID) + fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync) fid.to(device=device) fid.update(int_real_images, real=True) fid.update(int_generated_images, real=False) metrics["FID"] = fid.compute().item() if exists(IS): - inception = InceptionScore(**IS) + inception = InceptionScore(**IS, dist_sync_fn=null_sync) inception.to(device=device) inception.update(int_real_images) is_mean, is_std = inception.compute() metrics["IS_mean"] = is_mean.item() metrics["IS_std"] = is_std.item() if exists(KID): - kernel_inception = KernelInceptionDistance(**KID) + kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync) kernel_inception.to(device=device) kernel_inception.update(int_real_images, real=True) kernel_inception.update(int_generated_images, real=False) @@ -186,39 +196,47 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID # Convert from [0, 1] to [-1, 1] renorm_real_images = real_images.mul(2).sub(1) renorm_generated_images = generated_images.mul(2).sub(1) - lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS) + lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync) lpips.to(device=device) lpips.update(renorm_real_images, renorm_generated_images) metrics["LPIPS"] = lpips.compute().item() + + if trainer.accelerator.num_processes > 1: + # Then we should sync the metrics + metrics_order = sorted(metrics.keys()) + metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float) + for i, metric_name in enumerate(metrics_order): + metrics_tensor[0, i] = metrics[metric_name] + metrics_tensor = trainer.accelerator.gather(metrics_tensor) + metrics_tensor = metrics_tensor.mean(dim=0) + for i, metric_name in enumerate(metrics_order): + metrics[metric_name] = metrics_tensor[i].item() return metrics -def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths): +def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths): """ Logs the model with an appropriate method depending on the tracker """ if isinstance(relative_paths, str): relative_paths = [relative_paths] - trainer_state_dict = {} - trainer_state_dict["trainer"] = trainer.state_dict() - trainer_state_dict['epoch'] = epoch - trainer_state_dict['step'] = step - trainer_state_dict['validation_losses'] = validation_losses for relative_path in relative_paths: - tracker.save_state_dict(trainer_state_dict, relative_path) + local_path = str(tracker.data_path / relative_path) + trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses) + tracker.save_file(local_path) def recall_trainer(tracker, trainer, recall_source=None, **load_config): """ Loads the model with an appropriate method depending on the tracker """ - print(print_ribbon(f"Loading model from {recall_source}")) - state_dict = 
tracker.recall_state_dict(recall_source, **load_config.dict()) - trainer.load_state_dict(state_dict["trainer"]) - print("Model loaded") - return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"] + trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}")) + local_filepath = tracker.recall_file(recall_source, **load_config) + state_dict = trainer.load(local_filepath) + return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0) def train( dataloaders, decoder, + accelerator, tracker, inference_device, load_config=None, @@ -237,17 +255,30 @@ def train( """ Trains a decoder on a dataset. """ - trainer = DecoderTrainer( # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here + is_master = accelerator.process_index == 0 + + trainer = DecoderTrainer( + accelerator, decoder, **kwargs ) + # Set up starting model and parameters based on a recalled state dict - start_step = 0 start_epoch = 0 validation_losses = [] + next_task = 'train' + sample = 0 + val_sample = 0 + step = lambda: int(trainer.step.item()) if exists(load_config) and exists(load_config.source): - start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config) + start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict()) + if next_task == 'train': + sample = recalled_sample + if next_task == 'val': + val_sample = recalled_sample + accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}") + accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}") trainer.to(device=inference_device) if not exists(unet_training_mask): @@ -255,139 +286,185 @@ def train( unet_training_mask = [True] * trainer.num_unets assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. 
Got {len(unet_training_mask)} and {trainer.num_unets}" - print(print_ribbon("Generating Example Data", repeat=40)) - print("This can take a while to load the shard lists...") - train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images) - test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images) + accelerator.print(print_ribbon("Generating Example Data", repeat=40)) + accelerator.print("This can take a while to load the shard lists...") + if is_master: + train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images) + accelerator.print("Generated training examples") + test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images) + accelerator.print("Generated testing examples") send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr] - step = start_step + sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device) + unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device) for epoch in range(start_epoch, epochs): - print(print_ribbon(f"Starting epoch {epoch}", repeat=40)) + accelerator.print(print_ribbon(f"Starting epoch {epoch}", repeat=40)) timer = Timer() + last_sample = sample + last_snapshot = sample - sample = 0 - last_sample = 0 - last_snapshot = 0 + if next_task == 'train': + for i, (img, emb) in enumerate(dataloaders["train"]): + # We want to count the total number of samples across all processes + sample_length_tensor[0] = len(img) + all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this. + total_samples = all_samples.sum().item() + sample += total_samples + img, emb = send_to_device((img, emb)) - losses = [] + trainer.train() + for unet in range(1, trainer.num_unets+1): + # Check if this is a unet we are training + if not unet_training_mask[unet-1]: # Unet index is the unet number - 1 + continue - for i, (img, emb) in enumerate(dataloaders["train"]): - step += 1 - sample += img.shape[0] - img, emb = send_to_device((img, emb)) - - trainer.train() - for unet in range(1, trainer.num_unets+1): - # Check if this is a unet we are training - if not unet_training_mask[unet-1]: # Unet index is the unet number - 1 - continue + loss = trainer.forward(img, image_embed=emb, unet_number=unet) + trainer.update(unet_number=unet) + unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss + + samples_per_sec = (sample - last_sample) / timer.elapsed() + timer.reset() + last_sample = sample - loss = trainer.forward(img, image_embed=emb, unet_number=unet) - trainer.update(unet_number=unet) - losses.append(loss) + if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0: + # We want to average losses across all processes + unet_all_losses = accelerator.gather(unet_losses_tensor) + mask = unet_all_losses != 0 + unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0) + loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 } + log_data = { + "Epoch": epoch, + "Sample": sample, + "Step": i, + "Samples per second": samples_per_sec, + **loss_map + } + # print(f"I am rank {accelerator.state.process_index}. 
Example weight: {trainer.decoder.state_dict()['module.unets.0.init_conv.convs.0.weight'][0,0,0,0]}") + if is_master: + tracker.log(log_data, step=step(), verbose=True) - samples_per_sec = (sample - last_sample) / timer.elapsed() + if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope + # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master + print("Saving snapshot") + last_snapshot = sample + # We need to know where the model should be saved + save_paths = [] + if save_latest: + save_paths.append("latest.pth") + if save_all: + save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth") + save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths) + if exists(n_sample_images) and n_sample_images > 0: + trainer.eval() + train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ") + tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step()) + + if epoch_samples is not None and sample >= epoch_samples: + break + next_task = 'val' + sample = 0 - timer.reset() - last_sample = sample + all_average_val_losses = None + if next_task == 'val': + trainer.eval() + accelerator.print(print_ribbon(f"Starting Validation {epoch}", repeat=40)) + last_val_sample = val_sample + val_sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device) + average_val_loss_tensor = torch.zeros(1, trainer.num_unets, dtype=torch.float, device=inference_device) + timer = Timer() + accelerator.wait_for_everyone() + i = 0 + for i, (img, emb, txt) in enumerate(dataloaders["val"]): + val_sample_length_tensor[0] = len(img) + all_samples = accelerator.gather(val_sample_length_tensor) + total_samples = all_samples.sum().item() + val_sample += total_samples + img, emb = send_to_device((img, emb)) - if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0: - average_loss = sum(losses) / len(losses) - log_data = { - "Training loss": average_loss, - "Epoch": epoch, - "Sample": sample, - "Step": i, - "Samples per second": samples_per_sec - } - tracker.log(log_data, step=step, verbose=True) - losses = [] + for unet in range(1, len(decoder.unets)+1): + if not unet_training_mask[unet-1]: # Unet index is the unet number - 1 + # No need to evaluate an unchanging unet + continue + + loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet) + average_val_loss_tensor[0, unet-1] += loss - if last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... 
I hope - last_snapshot = sample - # We need to know where the model should be saved + if i % VALID_CALC_LOSS_EVERY_ITERS == 0: + samples_per_sec = (val_sample - last_val_sample) / timer.elapsed() + timer.reset() + last_val_sample = val_sample + accelerator.print(f"Epoch {epoch}/{epochs} Val Step {i} - Sample {val_sample} - {samples_per_sec:.2f} samples/sec") + accelerator.print(f"Loss: {(average_val_loss_tensor / (i+1))}") + accelerator.print("") + + if validation_samples is not None and val_sample >= validation_samples: + break + print(f"Rank {accelerator.state.process_index} finished validation after {i} steps") + accelerator.wait_for_everyone() + average_val_loss_tensor /= i+1 + # Gather all the average loss tensors + all_average_val_losses = accelerator.gather(average_val_loss_tensor) + if is_master: + unet_average_val_loss = all_average_val_losses.mean(dim=0) + val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 } + tracker.log(val_loss_map, step=step(), verbose=True) + next_task = 'eval' + + if next_task == 'eval': + if exists(evaluate_config): + accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40)) + evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict()) + if is_master: + tracker.log(evaluation, step=step(), verbose=True) + next_task = 'sample' + val_sample = 0 + + if next_task == 'sample': + if is_master: + # Generate examples and save the model if we are the master + # Generate sample images + print(print_ribbon(f"Sampling Set {epoch}", repeat=40)) + test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ") + train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ") + tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step()) + tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step()) + + print(print_ribbon(f"Starting Saving {epoch}", repeat=40)) + # Get the same paths save_paths = [] if save_latest: save_paths.append("latest.pth") - if save_all: - save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth") + if all_average_val_losses is not None: + average_loss = all_average_val_losses.mean(dim=0).item() + if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)): + save_paths.append("best.pth") + validation_losses.append(average_loss) + save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths) + next_task = 'train' - save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths) - - if exists(n_sample_images) and n_sample_images > 0: - trainer.eval() - train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ") - tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step) - - if exists(epoch_samples) and sample >= epoch_samples: - break - - trainer.eval() - print(print_ribbon(f"Starting Validation {epoch}", repeat=40)) - with torch.no_grad(): - sample = 0 - average_loss = 0 - timer = Timer() - for i, (img, emb, *_) in enumerate(dataloaders["val"]): - sample += img.shape[0] - img, emb = send_to_device((img, emb)) - - for unet in range(1, len(decoder.unets)+1): - loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet) - average_loss += loss - - if i % VALID_CALC_LOSS_EVERY_ITERS == 0: - print(f"Epoch {epoch}/{epochs} - 
{sample / timer.elapsed():.2f} samples/sec") - print(f"Loss: {average_loss / (i+1)}") - print("") - - if exists(validation_samples) and sample >= validation_samples: - break - - average_loss /= i+1 - log_data = { - "Validation loss": average_loss - } - tracker.log(log_data, step=step, verbose=True) - - # Compute evaluation metrics - if exists(evaluate_config): - print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40)) - evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict()) - tracker.log(evaluation, step=step, verbose=True) - - # Generate sample images - print(print_ribbon(f"Sampling Set {epoch}", repeat=40)) - test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ") - train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ") - tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step) - tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step) - - print(print_ribbon(f"Starting Saving {epoch}", repeat=40)) - # Get the same paths - save_paths = [] - if save_latest: - save_paths.append("latest.pth") - if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)): - save_paths.append("best.pth") - validation_losses.append(average_loss) - save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths) - -def create_tracker(config, tracker_type=None, data_path=None, **kwargs): +def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None): """ Creates a tracker of the specified type and initializes special features based on the full config """ tracker_config = config.tracker - init_config = {} + accelerator_config = { + "Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO, + "DistributedType": accelerator.distributed_type, + "NumProcesses": accelerator.num_processes, + "MixedPrecision": accelerator.mixed_precision + } + init_config = { "config": {**config.dict(), **accelerator_config} } + data_path = data_path or tracker_config.data_path + tracker_type = tracker_type or tracker_config.tracker_type - if exists(tracker_config.init_config): - init_config["config"] = tracker_config.init_config - - if tracker_type == "console": - tracker = ConsoleTracker(**init_config) + if tracker_type == "dummy": + tracker = DummyTracker(data_path) + tracker.init(**init_config) + elif tracker_type == "console": + tracker = ConsoleTracker(data_path) + tracker.init(**init_config) elif tracker_type == "wandb": # We need to initialize the resume state here load_config = config.load @@ -401,51 +478,63 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs): init_config["project"] = tracker_config.wandb_project tracker = WandbTracker(data_path) tracker.init(**init_config) + tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute())) else: raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer") return tracker -def initialize_training(config): - # Create the save path - if "cuda" in config.train.device: - assert torch.cuda.is_available(), "CUDA is not available" - device = torch.device(config.train.device) - torch.cuda.set_device(device) - all_shards = list(range(config.data.start_shard, config.data.end_shard + 1)) +def initialize_training(config, config_path): + # Make sure if we are not loading, distributed models are initialized to the same values + 
torch.manual_seed(config.seed) + # Set up accelerator for configurable distributed training + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator(kwargs_handlers=[ddp_kwargs]) + + # Set up data + all_shards = list(range(config.data.start_shard, config.data.end_shard + 1)) + world_size = accelerator.num_processes + rank = accelerator.process_index + shards_per_process = len(all_shards) // world_size + assert shards_per_process > 0, "Not enough shards to split evenly" + my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process] dataloaders = create_dataloaders ( - available_shards=all_shards, + available_shards=my_shards, img_preproc = config.data.img_preproc, train_prop = config.data.splits.train, val_prop = config.data.splits.val, test_prop = config.data.splits.test, n_sample_images=config.train.n_sample_images, - **config.data.dict() + **config.data.dict(), + rank = rank, + seed = config.seed, ) - decoder = config.decoder.create().to(device = device) + # Create the decoder model and print basic info + decoder = config.decoder.create() num_parameters = sum(p.numel() for p in decoder.parameters()) - print(print_ribbon("Loaded Config", repeat=40)) - print(f"Number of parameters: {num_parameters}") - tracker = create_tracker(config, **config.tracker.dict()) + # Create and initialize the tracker if we are the master + tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy") - train(dataloaders, decoder, + accelerator.print(print_ribbon("Loaded Config", repeat=40)) + accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training") + accelerator.print(f"Number of parameters: {num_parameters}") + train(dataloaders, decoder, accelerator, tracker=tracker, - inference_device=device, + inference_device=accelerator.device, load_config=config.load, evaluate_config=config.evaluate, **config.train.dict(), ) - + # Create a simple click command line interface to load the config and start the training @click.command() @click.option("--config_file", default="./train_decoder_config.json", help="Path to config file") def main(config_file): - print("Recalling config from {}".format(config_file)) - config = TrainDecoderConfig.from_json_path(config_file) - initialize_training(config) - + config_file_path = Path(config_file) + config = TrainDecoderConfig.from_json_path(str(config_file_path)) + initialize_training(config, config_path=config_file_path) if __name__ == "__main__": main()
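
For reference, below is a minimal sketch of how the accelerate-backed DecoderTrainer introduced by this patch is wired up, following what initialize_training() in train_decoder.py does. The config path and the use_ema/lr values are illustrative assumptions rather than part of the patch; the Accelerator setup and the new DecoderTrainer(accelerator, decoder, ...) signature come from the diff above. In a multi-GPU run the script would typically be started with "accelerate config" followed by "accelerate launch train_decoder.py --config_file <path>", which is standard Accelerate CLI usage, not something this patch adds.

# Hedged sketch (not part of the patch): constructing the accelerate-backed trainer
# the way initialize_training() does; config path and hyperparameters are assumptions.
import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.trainer import DecoderTrainer

config = TrainDecoderConfig.from_json_path("./train_decoder_config.json")  # assumed to exist
torch.manual_seed(config.seed)  # without a recalled checkpoint, every process starts from identical weights

# find_unused_parameters=True matches the patch, since only a subset of unets may be trained each step
accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])

decoder = config.decoder.create()  # built on CPU; accelerator.prepare() inside the trainer handles wrapping and device placement
trainer = DecoderTrainer(
    accelerator,      # new first positional argument added by this patch
    decoder,
    use_ema = True,   # illustrative values; train_decoder.py forwards fields from config.train instead
    lr = 1e-4,
)
trainer.to(device = accelerator.device)

# Per-unet training step, mirroring the script's inner loop (img and emb come from the dataloaders):
# loss = trainer.forward(img, image_embed = emb, unet_number = 1)
# trainer.update(unet_number = 1)

Checkpoints written through save_trainer() now carry epoch, sample, next_task and validation_losses alongside the model, per-unet optimizer and EMA state, which is what lets a resumed run pick up mid-epoch at the recorded task and sample count.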