From 5c520db82549cb7d1bee6f17e7e4399dd5cc9716 Mon Sep 17 00:00:00 2001
From: Aidan Dempster
Date: Fri, 8 Jul 2022 21:18:08 -0400
Subject: [PATCH] Added deepspeed support (#195)

---
 dalle2_pytorch/dataloaders/decoder_loader.py |  3 +-
 dalle2_pytorch/trainer.py                    | 37 ++++++++++++++++++--
 train_decoder.py                             | 22 ++++++++++--
 3 files changed, 56 insertions(+), 6 deletions(-)

diff --git a/dalle2_pytorch/dataloaders/decoder_loader.py b/dalle2_pytorch/dataloaders/decoder_loader.py
index 572036b..6b679e6 100644
--- a/dalle2_pytorch/dataloaders/decoder_loader.py
+++ b/dalle2_pytorch/dataloaders/decoder_loader.py
@@ -1,6 +1,7 @@
 import os
 import webdataset as wds
 import torch
+from torch.utils.data import DataLoader
 import numpy as np
 import fsspec
 import shutil
@@ -255,7 +256,7 @@ def create_image_embedding_dataloader(
     )
     if shuffle_num is not None and shuffle_num > 0:
         ds.shuffle(1000)
-    return wds.WebLoader(
+    return DataLoader(
         ds,
         num_workers=num_workers,
         batch_size=batch_size,
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 146057a..9202a02 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -21,7 +21,7 @@ import pytorch_warmup as warmup
 
 from ema_pytorch import EMA
 
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 
 import numpy as np
 
@@ -76,6 +76,7 @@ def cast_torch_tensor(fn):
     def inner(model, *args, **kwargs):
         device = kwargs.pop('_device', next(model.parameters()).device)
         cast_device = kwargs.pop('_cast_device', True)
+        cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
 
         kwargs_keys = kwargs.keys()
         all_args = (*args, *kwargs.values())
@@ -85,6 +86,21 @@ def cast_torch_tensor(fn):
         if cast_device:
             all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
 
+        if cast_deepspeed_precision:
+            try:
+                accelerator = model.accelerator
+                if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
+                    cast_type_map = {
+                        "fp16": torch.half,
+                        "bf16": torch.bfloat16,
+                        "no": torch.float
+                    }
+                    precision_type = cast_type_map[accelerator.mixed_precision]
+                    all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+            except AttributeError:
+                # Then this model doesn't have an accelerator
+                pass
+
         args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
         kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
 
@@ -446,6 +462,7 @@ class DecoderTrainer(nn.Module):
         self,
         decoder,
         accelerator = None,
+        dataloaders = None,
         use_ema = True,
         lr = 1e-4,
         wd = 1e-2,
@@ -508,8 +525,21 @@ class DecoderTrainer(nn.Module):
 
         self.register_buffer('steps', torch.tensor([0] * self.num_unets))
 
-        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+        if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
+            # Then we need to make sure clip is using the correct precision or else deepspeed will error
+            cast_type_map = {
+                "fp16": torch.half,
+                "bf16": torch.bfloat16,
+                "no": torch.float
+            }
+            precision_type = cast_type_map[accelerator.mixed_precision]
+            assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
+            clip = decoder.clip
+            clip.to(precision_type)
+        decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
 
+        self.train_loader = train_loader
+        self.val_loader = val_loader
         self.decoder = decoder
 
         # store optimizers
@@ -675,6 +705,9 @@ class DecoderTrainer(nn.Module):
 
         total_loss = 0.
 
+
+        using_amp = self.accelerator.mixed_precision != 'no'
+
         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with self.accelerator.autocast():
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
diff --git a/train_decoder.py b/train_decoder.py
index ee5c807..f9d36eb 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -274,6 +274,7 @@ def train(
     trainer = DecoderTrainer(
         decoder=decoder,
         accelerator=accelerator,
+        dataloaders=dataloaders,
         **kwargs
     )
 
@@ -284,7 +285,6 @@ def train(
     sample = 0
     samples_seen = 0
     val_sample = 0
-    step = lambda: int(trainer.num_steps_taken(unet_number=1))
 
     if tracker.can_recall:
         start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
@@ -299,6 +299,8 @@ def train(
     if not exists(unet_training_mask):
         # Then the unet mask should be true for all unets in the decoder
         unet_training_mask = [True] * trainer.num_unets
+    first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
+    step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
     assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
 
     accelerator.print(print_ribbon("Generating Example Data", repeat=40))
@@ -321,7 +323,7 @@ def train(
         last_snapshot = sample
 
         if next_task == 'train':
-            for i, (img, emb, txt) in enumerate(dataloaders["train"]):
+            for i, (img, emb, txt) in enumerate(trainer.train_loader):
                 # We want to count the total number of samples across all processes
                 sample_length_tensor[0] = len(img)
                 all_samples = accelerator.gather(sample_length_tensor)  # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -414,7 +416,7 @@ def train(
         timer = Timer()
         accelerator.wait_for_everyone()
         i = 0
-        for i, (img, emb, txt) in enumerate(dataloaders["val"]):
+        for i, (img, emb, txt) in enumerate(trainer.val_loader):  # Use the accelerate prepared loader
            val_sample_length_tensor[0] = len(img)
            all_samples = accelerator.gather(val_sample_length_tensor)
            total_samples = all_samples.sum().item()
@@ -519,6 +521,20 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     # Set up accelerator for configurable distributed training
     ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
     accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
+
+    if accelerator.num_processes > 1:
+        # We are using distributed training and want to immediately ensure all can connect
+        accelerator.print("Waiting for all processes to connect...")
+        accelerator.wait_for_everyone()
+        accelerator.print("All processes online and connected")
+
+    # If we are in deepspeed fp16 mode, we must ensure learned variance is off
+    if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
+        raise ValueError("DeepSpeed fp16 mode does not support learned variance")
+
+    if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
+        # This is an invalid configuration until we figure out how to handle this
+        raise ValueError("DeepSpeed does not support multi-node distributed training")
 
     # Set up data
     all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
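
Note, not part of the patch: a minimal standalone sketch of the precision-casting technique the cast_torch_tensor hunk applies when accelerate reports DistributedType.DEEPSPEED, i.e. mapping accelerator.mixed_precision ("fp16", "bf16", or "no") to a torch dtype and casting tensor arguments before the forward call. The helper name cast_args_for_deepspeed is illustrative and does not exist in the repository.

import torch

# Same dtype mapping as the patch uses in cast_torch_tensor and DecoderTrainer.__init__
cast_type_map = {
    "fp16": torch.half,
    "bf16": torch.bfloat16,
    "no": torch.float,
}

def cast_args_for_deepspeed(args, mixed_precision):
    # Cast every tensor argument to the dtype DeepSpeed expects; non-tensors pass through unchanged
    precision_type = cast_type_map[mixed_precision]
    return tuple(a.to(precision_type) if isinstance(a, torch.Tensor) else a for a in args)

if __name__ == "__main__":
    example_args = (torch.randn(2, 3), "a caption", 42)
    casted = cast_args_for_deepspeed(example_args, "fp16")
    print([a.dtype if isinstance(a, torch.Tensor) else a for a in casted])
    # -> [torch.float16, 'a caption', 42]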
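
Also not part of the patch: an illustrative guard collecting the configuration checks the initialize_training hunk adds, namely rejecting DeepSpeed fp16 when the decoder uses learned variance and rejecting multi-node DeepSpeed runs. The function name validate_deepspeed_setup and the learned_variance argument are assumptions for the sketch; in the patch these values come from the TrainDecoderConfig.

from accelerate import Accelerator, DistributedType

def validate_deepspeed_setup(accelerator: Accelerator, learned_variance: bool) -> None:
    # Only relevant when accelerate is driving DeepSpeed
    if accelerator.distributed_type != DistributedType.DEEPSPEED:
        return
    # DeepSpeed fp16 mode does not play well with learned variance in the decoder
    if accelerator.mixed_precision == "fp16" and learned_variance:
        raise ValueError("DeepSpeed fp16 mode does not support learned variance")
    # A global rank differing from the local rank implies a multi-node launch
    if accelerator.process_index != accelerator.local_process_index:
        raise ValueError("DeepSpeed does not support multi-node distributed training")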