Added deepspeed support (#195)

Author: Aidan Dempster
Date: 2022-07-08 21:18:08 -04:00 (committed by GitHub)
Parent commit: 3070610231
Commit: 5c520db825
3 changed files with 56 additions and 6 deletions

View File

@@ -1,6 +1,7 @@
 import os
 import webdataset as wds
 import torch
+from torch.utils.data import DataLoader
 import numpy as np
 import fsspec
 import shutil
@@ -255,7 +256,7 @@ def create_image_embedding_dataloader(
     )
     if shuffle_num is not None and shuffle_num > 0:
         ds.shuffle(1000)
-    return wds.WebLoader(
+    return DataLoader(
         ds,
         num_workers=num_workers,
         batch_size=batch_size,
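
Note: with this change create_image_embedding_dataloader returns a plain torch.utils.data.DataLoader instead of a wds.WebLoader. A minimal usage sketch follows; only num_workers, batch_size, and shuffle_num are visible in the hunk above, so the remaining argument names and values are illustrative assumptions.

    from torch.utils.data import DataLoader

    # Sketch only: tar_url and its value are hypothetical placeholders.
    dataloader = create_image_embedding_dataloader(
        tar_url="data/{00000..00099}.tar",
        num_workers=4,
        batch_size=32,
        shuffle_num=200,
    )
    assert isinstance(dataloader, DataLoader)  # previously a wds.WebLoader
    for batch in dataloader:
        break  # batches now come from the standard PyTorch loader machinery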

View File

@@ -21,7 +21,7 @@ import pytorch_warmup as warmup
 from ema_pytorch import EMA
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 import numpy as np
@@ -76,6 +76,7 @@ def cast_torch_tensor(fn):
     def inner(model, *args, **kwargs):
         device = kwargs.pop('_device', next(model.parameters()).device)
         cast_device = kwargs.pop('_cast_device', True)
+        cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)

         kwargs_keys = kwargs.keys()
         all_args = (*args, *kwargs.values())
@@ -85,6 +86,21 @@ def cast_torch_tensor(fn):
         if cast_device:
             all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

+        if cast_deepspeed_precision:
+            try:
+                accelerator = model.accelerator
+                if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
+                    cast_type_map = {
+                        "fp16": torch.half,
+                        "bf16": torch.bfloat16,
+                        "no": torch.float
+                    }
+                    precision_type = cast_type_map[accelerator.mixed_precision]
+                    all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+            except AttributeError:
+                # Then this model doesn't have an accelerator
+                pass
+
         args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
         kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
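
Note: the added branch boils down to mapping Accelerate's mixed_precision string to a torch dtype and casting tensor arguments before they reach the DeepSpeed engine. A standalone sketch of that idea (the helper name and signature are ours, not from the codebase):

    import torch
    from accelerate import Accelerator, DistributedType

    # Illustrative helper, assuming the same fp16/bf16/no mapping used above.
    def cast_args_for_deepspeed(accelerator: Accelerator, *args):
        if accelerator is None or accelerator.distributed_type != DistributedType.DEEPSPEED:
            return args
        cast_type_map = {"fp16": torch.half, "bf16": torch.bfloat16, "no": torch.float}
        precision_type = cast_type_map[accelerator.mixed_precision]
        return tuple(a.to(precision_type) if isinstance(a, torch.Tensor) else a for a in args)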
@@ -446,6 +462,7 @@ class DecoderTrainer(nn.Module):
         self,
         decoder,
         accelerator = None,
+        dataloaders = None,
         use_ema = True,
         lr = 1e-4,
         wd = 1e-2,
@@ -508,8 +525,21 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('steps', torch.tensor([0] * self.num_unets))

+        if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
+            # Then we need to make sure clip is using the correct precision or else deepspeed will error
+            cast_type_map = {
+                "fp16": torch.half,
+                "bf16": torch.bfloat16,
+                "no": torch.float
+            }
+            precision_type = cast_type_map[accelerator.mixed_precision]
+            assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
+            clip = decoder.clip
+            clip.to(precision_type)
+
-        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+        decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
+        self.train_loader = train_loader
+        self.val_loader = val_loader
         self.decoder = decoder

         # store optimizers
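
Note: DecoderTrainer now takes the raw dataloaders and runs them through accelerator.prepare alongside the decoder and optimizers, exposing the prepared loaders as trainer.train_loader and trainer.val_loader. A rough sketch of the intended calling pattern (values are placeholders; only the decoder, accelerator, and dataloaders arguments and the "train"/"val" keys are taken from the diff):

    trainer = DecoderTrainer(
        decoder=decoder,
        accelerator=accelerator,
        dataloaders={"train": train_dataloader, "val": val_dataloader},
        lr=1e-4,
        wd=1e-2,
    )

    # Iterate the prepared loader via the trainer rather than the raw dataloaders
    # so the sharding set up by accelerator.prepare(...) is respected.
    for img, emb, txt in trainer.train_loader:
        pass  # run a training step here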
@@ -675,6 +705,9 @@ class DecoderTrainer(nn.Module):
         total_loss = 0.
+        using_amp = self.accelerator.mixed_precision != 'no'

         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with self.accelerator.autocast():
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
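
Note: the forward pass already wraps the decoder call in accelerator.autocast(), and the using_amp flag simply records whether any mixed precision is configured. A minimal standalone sketch of that pattern, assuming nothing beyond the public Accelerate API:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()                       # pass mixed_precision="fp16" or "bf16" on a suitable GPU setup
    model = accelerator.prepare(torch.nn.Linear(8, 8))
    x = torch.randn(4, 8, device=accelerator.device)

    using_amp = accelerator.mixed_precision != 'no'   # mirrors the flag added above
    with accelerator.autocast():                      # no-op when mixed precision is 'no'
        out = model(x)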

View File

@@ -274,6 +274,7 @@ def train(
     trainer = DecoderTrainer(
         decoder=decoder,
         accelerator=accelerator,
+        dataloaders=dataloaders,
         **kwargs
     )
@@ -284,7 +285,6 @@ def train(
     sample = 0
     samples_seen = 0
     val_sample = 0
-    step = lambda: int(trainer.num_steps_taken(unet_number=1))

     if tracker.can_recall:
         start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
@@ -299,6 +299,8 @@ def train(
     if not exists(unet_training_mask):
         # Then the unet mask should be true for all unets in the decoder
         unet_training_mask = [True] * trainer.num_unets
+    first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
+    step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
     assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"

     accelerator.print(print_ribbon("Generating Example Data", repeat=40))
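
Note: the step counter now follows the first unet that is actually being trained rather than always unet 1. A worked example of the expression above:

    # unet 1 frozen, unets 2 and 3 trained (zero-based mask indices)
    unet_training_mask = [False, True, True]
    first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
    assert first_training_unet == 1  # so step() reads trainer.num_steps_taken(unet_number=2)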
@@ -321,7 +323,7 @@ def train(
         last_snapshot = sample

         if next_task == 'train':
-            for i, (img, emb, txt) in enumerate(dataloaders["train"]):
+            for i, (img, emb, txt) in enumerate(trainer.train_loader):
                 # We want to count the total number of samples across all processes
                 sample_length_tensor[0] = len(img)
                 all_samples = accelerator.gather(sample_length_tensor)  # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -414,7 +416,7 @@ def train(
             timer = Timer()
             accelerator.wait_for_everyone()
             i = 0
-            for i, (img, emb, txt) in enumerate(dataloaders["val"]):
+            for i, (img, emb, txt) in enumerate(trainer.val_loader):  # Use the accelerate prepared loader
                 val_sample_length_tensor[0] = len(img)
                 all_samples = accelerator.gather(val_sample_length_tensor)
                 total_samples = all_samples.sum().item()
@@ -520,6 +522,20 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
     accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
+    if accelerator.num_processes > 1:
+        # We are using distributed training and want to immediately ensure all can connect
+        accelerator.print("Waiting for all processes to connect...")
+        accelerator.wait_for_everyone()
+        accelerator.print("All processes online and connected")
+
+    # If we are in deepspeed fp16 mode, we must ensure learned variance is off
+    if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
+        raise ValueError("DeepSpeed fp16 mode does not support learned variance")
+
+    if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
+        # This is an invalid configuration until we figure out how to handle this
+        raise ValueError("DeepSpeed does not support multi-node distributed training")
+
     # Set up data
     all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
     world_size = accelerator.num_processes
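
Note: taken together, the new guards amount to a small DeepSpeed compatibility check. An illustrative restatement as a standalone helper (the function itself is ours, not code from the PR):

    from accelerate import Accelerator
    from accelerate.utils import DistributedType

    def validate_deepspeed_setup(accelerator: Accelerator, learned_variance: bool) -> None:
        # Sketch of the checks added above; only relevant when DeepSpeed is the backend.
        if accelerator.distributed_type != DistributedType.DEEPSPEED:
            return
        if accelerator.mixed_precision == "fp16" and learned_variance:
            raise ValueError("DeepSpeed fp16 mode does not support learned variance")
        if accelerator.process_index != accelerator.local_process_index:
            # Differing global and local ranks imply more than one node
            raise ValueError("DeepSpeed does not support multi-node distributed training")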