Added deepspeed support (#195)

Aidan Dempster
2022-07-08 21:18:08 -04:00
committed by GitHub
parent 3070610231
commit 5c520db825
3 changed files with 56 additions and 6 deletions

@@ -21,7 +21,7 @@ import pytorch_warmup as warmup
 from ema_pytorch import EMA
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 import numpy as np
@@ -76,6 +76,7 @@ def cast_torch_tensor(fn):
     def inner(model, *args, **kwargs):
         device = kwargs.pop('_device', next(model.parameters()).device)
         cast_device = kwargs.pop('_cast_device', True)
+        cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
         kwargs_keys = kwargs.keys()
         all_args = (*args, *kwargs.values())
@@ -85,6 +86,21 @@ def cast_torch_tensor(fn):
         if cast_device:
             all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+        if cast_deepspeed_precision:
+            try:
+                accelerator = model.accelerator
+                if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
+                    cast_type_map = {
+                        "fp16": torch.half,
+                        "bf16": torch.bfloat16,
+                        "no": torch.float
+                    }
+                    precision_type = cast_type_map[accelerator.mixed_precision]
+                    all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+            except AttributeError:
+                # Then this model doesn't have an accelerator
+                pass
         args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
         kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
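
The new branch mirrors accelerate's mixed_precision setting: when the Accelerator is running under DeepSpeed, every incoming tensor is cast to the dtype DeepSpeed expects before the wrapped method runs. Passing _cast_deepspeed_precision = False to a wrapped call skips this, since the flag is popped from kwargs before the model ever sees it. A minimal standalone sketch of the mapping (the Accelerator construction here is illustrative; in the trainer the accelerator is read off the model):

    import torch
    from accelerate import Accelerator, DistributedType

    accelerator = Accelerator()  # mixed_precision / DeepSpeed come from the accelerate config or launch flags

    cast_type_map = {
        "fp16": torch.half,
        "bf16": torch.bfloat16,
        "no": torch.float
    }

    x = torch.randn(2, 4)
    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        # DeepSpeed feeds the model in its configured precision, so inputs must match it
        x = x.to(cast_type_map[accelerator.mixed_precision])
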
@@ -446,6 +462,7 @@ class DecoderTrainer(nn.Module):
         self,
         decoder,
         accelerator = None,
+        dataloaders = None,
         use_ema = True,
         lr = 1e-4,
         wd = 1e-2,
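
With the new dataloaders argument, the trainer hands the train and validation loaders to accelerator.prepare so DeepSpeed can manage them alongside the model and optimizers. A hedged usage sketch (the tensors here are stand-ins just to build loaders; the decoder construction is omitted):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # stand-in datasets purely to build example loaders
    train_ds = TensorDataset(torch.randn(64, 3, 64, 64))
    val_ds = TensorDataset(torch.randn(8, 3, 64, 64))

    dataloaders = {
        "train": DataLoader(train_ds, batch_size = 4, shuffle = True),
        "val": DataLoader(val_ds, batch_size = 4)
    }

    # trainer = DecoderTrainer(decoder, dataloaders = dataloaders)
    # trainer.train_loader and trainer.val_loader are then the accelerator-prepared loaders
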
@@ -508,8 +525,21 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('steps', torch.tensor([0] * self.num_unets))
-        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+        if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
+            # Then we need to make sure clip is using the correct precision or else deepspeed will error
+            cast_type_map = {
+                "fp16": torch.half,
+                "bf16": torch.bfloat16,
+                "no": torch.float
+            }
+            precision_type = cast_type_map[accelerator.mixed_precision]
+            assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
+            clip = decoder.clip
+            clip.to(precision_type)
+        decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
+        self.train_loader = train_loader
+        self.val_loader = val_loader
         self.decoder = decoder
         # store optimizers
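
Under DeepSpeed, accelerate needs the dataloaders wrapped in the same prepare call as the model and optimizers, which is why the loaders now flow through __init__ instead of being iterated externally; the assert also limits on-the-fly CLIP embedding generation to float32 when DeepSpeed is the backend. A generic accelerate sketch of the prepare pattern (toy model, optimizer, and data are placeholders):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate import Accelerator

    accelerator = Accelerator()

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
    train_dl = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size = 4)
    val_dl = DataLoader(TensorDataset(torch.randn(8, 8)), batch_size = 4)

    # everything returned is wrapped for the active backend (DDP, DeepSpeed, ...)
    model, train_dl, val_dl, optimizer = accelerator.prepare(model, train_dl, val_dl, optimizer)
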
@@ -675,6 +705,9 @@ class DecoderTrainer(nn.Module):
         total_loss = 0.
+        using_amp = self.accelerator.mixed_precision != 'no'
         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with self.accelerator.autocast():
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
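
accelerator.autocast() applies whatever precision policy accelerate was configured with (fp16, bf16, or none), so the same forward code serves plain AMP and DeepSpeed alike. A minimal training-step sketch with that context manager (toy model and data, assuming the usual pairing with accelerator.backward):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
    model, optimizer = accelerator.prepare(model, optimizer)

    x = torch.randn(4, 8, device = accelerator.device)

    with accelerator.autocast():
        # runs in fp16 / bf16 / fp32 depending on the accelerate configuration
        loss = model(x).mean()

    accelerator.backward(loss)  # handles loss scaling for fp16 and delegates to DeepSpeed's backward when active
    optimizer.step()
    optimizer.zero_grad()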