mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Added deepspeed support (#195)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import webdataset as wds
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
import fsspec
|
||||
import shutil
|
||||
@@ -255,7 +256,7 @@ def create_image_embedding_dataloader(
|
||||
)
|
||||
if shuffle_num is not None and shuffle_num > 0:
|
||||
ds.shuffle(1000)
|
||||
return wds.WebLoader(
|
||||
return DataLoader(
|
||||
ds,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
|
||||
@@ -21,7 +21,7 @@ import pytorch_warmup as warmup
|
||||
|
||||
from ema_pytorch import EMA
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DistributedType
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -76,6 +76,7 @@ def cast_torch_tensor(fn):
|
||||
def inner(model, *args, **kwargs):
|
||||
device = kwargs.pop('_device', next(model.parameters()).device)
|
||||
cast_device = kwargs.pop('_cast_device', True)
|
||||
cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
|
||||
|
||||
kwargs_keys = kwargs.keys()
|
||||
all_args = (*args, *kwargs.values())
|
||||
@@ -85,6 +86,21 @@ def cast_torch_tensor(fn):
|
||||
if cast_device:
|
||||
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||
|
||||
if cast_deepspeed_precision:
|
||||
try:
|
||||
accelerator = model.accelerator
|
||||
if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
cast_type_map = {
|
||||
"fp16": torch.half,
|
||||
"bf16": torch.bfloat16,
|
||||
"no": torch.float
|
||||
}
|
||||
precision_type = cast_type_map[accelerator.mixed_precision]
|
||||
all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||
except AttributeError:
|
||||
# Then this model doesn't have an accelerator
|
||||
pass
|
||||
|
||||
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
|
||||
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
|
||||
|
||||
@@ -446,6 +462,7 @@ class DecoderTrainer(nn.Module):
|
||||
self,
|
||||
decoder,
|
||||
accelerator = None,
|
||||
dataloaders = None,
|
||||
use_ema = True,
|
||||
lr = 1e-4,
|
||||
wd = 1e-2,
|
||||
@@ -508,8 +525,21 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
|
||||
|
||||
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
|
||||
# Then we need to make sure clip is using the correct precision or else deepspeed will error
|
||||
cast_type_map = {
|
||||
"fp16": torch.half,
|
||||
"bf16": torch.bfloat16,
|
||||
"no": torch.float
|
||||
}
|
||||
precision_type = cast_type_map[accelerator.mixed_precision]
|
||||
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
|
||||
clip = decoder.clip
|
||||
clip.to(precision_type)
|
||||
decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
|
||||
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
self.decoder = decoder
|
||||
|
||||
# store optimizers
|
||||
@@ -675,6 +705,9 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
total_loss = 0.
|
||||
|
||||
|
||||
using_amp = self.accelerator.mixed_precision != 'no'
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
with self.accelerator.autocast():
|
||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||
|
||||
@@ -274,6 +274,7 @@ def train(
|
||||
trainer = DecoderTrainer(
|
||||
decoder=decoder,
|
||||
accelerator=accelerator,
|
||||
dataloaders=dataloaders,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -284,7 +285,6 @@ def train(
|
||||
sample = 0
|
||||
samples_seen = 0
|
||||
val_sample = 0
|
||||
step = lambda: int(trainer.num_steps_taken(unet_number=1))
|
||||
|
||||
if tracker.can_recall:
|
||||
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
|
||||
@@ -299,6 +299,8 @@ def train(
|
||||
if not exists(unet_training_mask):
|
||||
# Then the unet mask should be true for all unets in the decoder
|
||||
unet_training_mask = [True] * trainer.num_unets
|
||||
first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
|
||||
step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
|
||||
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
|
||||
|
||||
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
|
||||
@@ -321,7 +323,7 @@ def train(
|
||||
last_snapshot = sample
|
||||
|
||||
if next_task == 'train':
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
|
||||
for i, (img, emb, txt) in enumerate(trainer.train_loader):
|
||||
# We want to count the total number of samples across all processes
|
||||
sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
||||
@@ -414,7 +416,7 @@ def train(
|
||||
timer = Timer()
|
||||
accelerator.wait_for_everyone()
|
||||
i = 0
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
||||
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
|
||||
val_sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(val_sample_length_tensor)
|
||||
total_samples = all_samples.sum().item()
|
||||
@@ -519,6 +521,20 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
# Set up accelerator for configurable distributed training
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
||||
|
||||
if accelerator.num_processes > 1:
|
||||
# We are using distributed training and want to immediately ensure all can connect
|
||||
accelerator.print("Waiting for all processes to connect...")
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.print("All processes online and connected")
|
||||
|
||||
# If we are in deepspeed fp16 mode, we must ensure learned variance is off
|
||||
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
|
||||
raise ValueError("DeepSpeed fp16 mode does not support learned variance")
|
||||
|
||||
if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
|
||||
# This is an invalid configuration until we figure out how to handle this
|
||||
raise ValueError("DeepSpeed does not support multi-node distributed training")
|
||||
|
||||
# Set up data
|
||||
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
|
||||
|
||||
Reference in New Issue
Block a user