Distributed Training of the Decoder (#121)

* Converted the decoder trainer to use Hugging Face Accelerate (see the launch sketch below the commit metadata)

* Fixed an issue where metric evaluation would hang in distributed mode

* Implemented functional saving
Loading still fails due to some issue with the optimizer

* Fixed issue with loading decoders

* Fixed issue with tracker config

* Fixed an issue with AMP (automatic mixed precision)
Updated the logging to be more logical

* Saving a checkpoint now records the position in training as well
Fixed an issue with running out of GPU memory caused by loading the weights onto the GPU twice

* Fixed EMA for distributed training

* Fixed an issue where get_pkg_version was reintroduced

* Changed decoder trainer to upload config as a file

* Fixed an issue where loading the best checkpoint would error
Commit 58892135d9 (parent e37072a48c), authored by Aidan Dempster on 2022-06-19 12:25:54 -04:00 and committed via GitHub.
7 changed files with 331 additions and 207 deletions.
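In practice the conversion means the training script is started through Accelerate's launcher, with one `Accelerator` per process handling device placement, mixed precision and gradient synchronization. A minimal launch sketch; the script name and flag are illustrative, not necessarily this repo's exact CLI:

```python
# Shell (illustrative):
#   accelerate config                                            # one-time machine setup
#   accelerate launch train_decoder.py --config_file my_config.json
from accelerate import Accelerator

accelerator = Accelerator()  # picks up the distributed environment created by `accelerate launch`
accelerator.print(f'rank {accelerator.process_index} of {accelerator.num_processes} on {accelerator.device}')
```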

@@ -2099,7 +2099,8 @@ class Decoder(BaseGaussianDiffusion):
text_encodings = None,
batch_size = 1,
cond_scale = 1.,
- stop_at_unet_number = None
+ stop_at_unet_number = None,
+ distributed = False,
):
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
@@ -2118,7 +2119,7 @@ class Decoder(BaseGaussianDiffusion):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
- context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
+ context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
with context:
lowres_cond_img = None
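The new `distributed` flag stops the decoder from shuttling individual unets on and off the GPU while sampling under multiple processes: the Accelerate-prepared model already lives on each process's device, so the per-unet swap only pays off for single-process CUDA sampling. A paraphrase of that conditional; the helper name is mine and it uses the stdlib `nullcontext` in place of the repo's `null_context`:

```python
from contextlib import nullcontext

def pick_unet_context(decoder, unet, is_cuda, distributed):
    # Single-process CUDA sampling: keep only one unet on the GPU at a time to save memory.
    # Distributed sampling: the whole decoder is already placed per process, so change nothing.
    if is_cuda and not distributed:
        return decoder.one_unet_in_gpu(unet = unet)
    return nullcontext()
```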

@@ -164,9 +164,6 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
- self.append(wds.split_by_node)
- self.append(wds.split_by_worker)
self.append(wds.tarfile_to_samples(handler=handler))
self.append(wds.decode("pilrgb", handler=handler))
if embedding_folder_url is not None:
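Dropping `wds.split_by_node` and `wds.split_by_worker` from the pipeline suggests shard assignment now happens before the dataset is built; giving every rank the same number of shards is one common way to avoid the distributed evaluation hangs mentioned in the commit message. The snippet below only illustrates that general idea and is not the loader code from this commit:

```python
import os

def shards_for_this_process(shard_urls):
    # Give each rank a disjoint, equally sized slice of shards; uneven shard counts can
    # leave some ranks waiting forever inside collective ops during evaluation.
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    per_rank = len(shard_urls) // world_size      # drop the remainder so every rank matches
    start = rank * per_rank
    return shard_urls[start:start + per_rank]

urls = [f'dataset-{i:05d}.tar' for i in range(100)]
print(shards_for_this_process(urls))
```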

@@ -17,15 +17,15 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val):
return val is not None
- # load state dict functions
+ # load file functions
- def load_wandb_state_dict(run_path, file_path, **kwargs):
+ def load_wandb_file(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
- return torch.load(file_reference.name)
+ return file_reference.name
- def load_local_state_dict(file_path, **kwargs):
- return torch.load(file_path)
+ def load_local_file(file_path, **kwargs):
+ return file_path
# base class
@@ -55,12 +55,43 @@ class BaseTracker(nn.Module):
"""
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
if recall_source == 'wandb':
- return load_wandb_state_dict(*args, **kwargs)
+ return torch.load(load_wandb_file(*args, **kwargs))
elif recall_source == 'local':
- return load_local_state_dict(*args, **kwargs)
+ return torch.load(load_local_file(*args, **kwargs))
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
+ def save_file(self, file_path, **kwargs):
+ raise NotImplementedError
+ def recall_file(self, recall_source, *args, **kwargs):
+ if recall_source == 'wandb':
+ return load_wandb_file(*args, **kwargs)
+ elif recall_source == 'local':
+ return load_local_file(*args, **kwargs)
+ else:
+ raise ValueError('`recall_source` must be one of `wandb` or `local`')
+ # Tracker that no-ops all calls except for recall
+ class DummyTracker(BaseTracker):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ def init(self, config, **kwargs):
+ pass
+ def log(self, log, **kwargs):
+ pass
+ def log_images(self, images, **kwargs):
+ pass
+ def save_state_dict(self, state_dict, relative_path, **kwargs):
+ pass
+ def save_file(self, file_path, **kwargs):
+ pass
# basic stdout class
@@ -76,6 +107,10 @@ class ConsoleTracker(BaseTracker):
def save_state_dict(self, state_dict, relative_path, **kwargs):
torch.save(state_dict, str(self.data_path / relative_path))
+ def save_file(self, file_path, **kwargs):
+ # This is a no-op for local file systems since it is already saved locally
+ pass
# basic wandb class
@@ -107,3 +142,11 @@ class WandbTracker(BaseTracker):
full_path = str(self.data_path / relative_path)
torch.save(state_dict, full_path)
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
+ def save_file(self, file_path, base_path=None, **kwargs):
+ """
+ Uploads a file from disk to wandb
+ """
+ if base_path is None:
+ base_path = self.data_path
+ self.wandb.save(str(file_path), base_path = str(base_path))
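The tracker API now moves whole files (for example the uploaded training config) in addition to state dicts, and the wandb implementation leans on `wandb.save` and `wandb.restore`. A usage sketch with placeholder project, run and file names:

```python
import torch
import wandb

run = wandb.init(project = 'decoder-example')                                        # placeholder project
wandb.save('./.tracker-data/decoder_config.json', base_path = './.tracker-data')     # upload, keeping the path relative to base_path

# Later, possibly on another machine: pull a file back out of a finished run
restored = wandb.restore('latest.pth', run_path = 'my-entity/decoder-example/abc123')  # placeholder run path
state_dict = torch.load(restored.name, map_location = 'cpu')                         # restore() returns a local file handle
```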

@@ -261,6 +261,7 @@ class TrainDecoderConfig(BaseModel):
evaluate: DecoderEvaluateConfig
tracker: TrackerConfig
load: DecoderLoadConfig
+ seed: int = 0
@classmethod
def from_json_path(cls, json_path):
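The new `seed` field gives the config a single knob for pinning randomness, which matters in distributed runs where every process must initialize identically. A plausible consumption pattern, with a stand-in config model; the exact wiring in the training script is an assumption:

```python
from pydantic import BaseModel
from accelerate.utils import set_seed

class MinimalTrainConfig(BaseModel):   # stand-in for TrainDecoderConfig
    seed: int = 0
    lr: float = 1e-4

config = MinimalTrainConfig(seed = 42)
set_seed(config.seed)                  # same seed on every process keeps weight init and shuffling in sync
```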

@@ -574,6 +574,7 @@ def decoder_sample_in_chunks(fn):
class DecoderTrainer(nn.Module):
def __init__(
self,
+ accelerator,
decoder,
use_ema = True,
lr = 1e-4,
@@ -588,8 +589,9 @@ class DecoderTrainer(nn.Module):
assert isinstance(decoder, Decoder)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
- self.decoder = decoder
- self.num_unets = len(self.decoder.unets)
+ self.accelerator = accelerator
+ self.num_unets = len(decoder.unets)
self.use_ema = use_ema
self.ema_unets = nn.ModuleList([])
@@ -601,7 +603,9 @@ class DecoderTrainer(nn.Module):
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
- for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
+ optimizers = []
+ for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
@@ -611,19 +615,20 @@ class DecoderTrainer(nn.Module):
**kwargs
)
- setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
+ optimizers.append(optimizer)
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
- scaler = GradScaler(enabled = amp)
- setattr(self, f'scaler{ind}', scaler)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
+ results = list(self.accelerator.prepare(decoder, *optimizers))
+ self.decoder = results.pop(0)
+ for opt_ind in range(len(optimizers)):
+ setattr(self, f'optim{opt_ind}', results.pop(0))
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
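Because the trainer keeps one optimizer per unet, the decoder and all of its optimizers go through a single `accelerator.prepare` call and are unpacked back into `optim{i}` attributes. A generic sketch of that multi-optimizer prepare pattern with toy modules:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
unets = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(3)])    # toy stand-ins for the unets
optimizers = [torch.optim.AdamW(u.parameters(), lr = 1e-4) for u in unets]

# prepare() returns its arguments in the same order, wrapped for the current distributed setup
prepared = list(accelerator.prepare(unets, *optimizers))
unets = prepared.pop(0)
optimizers = [prepared.pop(0) for _ in range(len(optimizers))]
```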
@@ -631,47 +636,42 @@ class DecoderTrainer(nn.Module):
path.parent.mkdir(parents = True, exist_ok = True)
save_obj = dict(
- model = self.decoder.state_dict(),
+ model = self.accelerator.unwrap_model(self.decoder).state_dict(),
version = __version__,
step = self.step.item(),
**kwargs
)
for ind in range(0, self.num_unets):
- scaler_key = f'scaler{ind}'
- optimizer_key = f'scaler{ind}'
- scaler = getattr(self, scaler_key)
+ optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
- save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
+ save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
- torch.save(save_obj, str(path))
+ self.accelerator.save(save_obj, str(path))
def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()
- loaded_obj = torch.load(str(path))
+ loaded_obj = torch.load(str(path), map_location = 'cpu')
- if version.parse(__version__) != loaded_obj['version']:
- print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
+ if version.parse(__version__) != version.parse(loaded_obj['version']):
+ self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
- self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
+ self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return loaded_obj
for ind in range(0, self.num_unets):
- scaler_key = f'scaler{ind}'
- optimizer_key = f'scaler{ind}'
- scaler = getattr(self, scaler_key)
+ optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
- scaler.load_state_dict(loaded_obj[scaler_key])
- optimizer.load_state_dict(loaded_obj[optimizer_key])
+ self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
if self.use_ema:
assert 'ema' in loaded_obj
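Checkpoints are now written from the unwrapped model through `accelerator.save` and read back with `map_location = 'cpu'`, so loading does not materialize a second full copy of the weights on the GPU (the "loading weights into the GPU twice" problem from the commit message). A sketch of that round trip with a toy model and an illustrative file name:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(64, 64))   # toy stand-in for the decoder

# Save once, from the main process, after unwrapping the DDP/Accelerate wrapper
if accelerator.is_main_process:
    state = {'model': accelerator.unwrap_model(model).state_dict(), 'step': 1000}
    accelerator.save(state, 'decoder_checkpoint.pt')
accelerator.wait_for_everyone()

# Load onto CPU first, then copy into the already device-placed unwrapped model
ckpt = torch.load('decoder_checkpoint.pt', map_location = 'cpu')
accelerator.unwrap_model(model).load_state_dict(ckpt['model'])
```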
@@ -683,29 +683,18 @@ class DecoderTrainer(nn.Module):
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
- def scale(self, loss, *, unet_number):
- assert 1 <= unet_number <= self.num_unets
- index = unet_number - 1
- scaler = getattr(self, f'scaler{index}')
- return scaler.scale(loss)
def update(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
index = unet_number - 1
- unet = self.decoder.unets[index]
optimizer = getattr(self, f'optim{index}')
- scaler = getattr(self, f'scaler{index}')
if exists(self.max_grad_norm):
- scaler.unscale_(optimizer)
- nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
- scaler.step(optimizer)
- scaler.update()
+ self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
+ optimizer.step()
optimizer.zero_grad()
if self.use_ema:
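Gradient clipping now goes through `accelerator.clip_grad_norm_`, which unscales AMP gradients itself and has to sit between `accelerator.backward` and `optimizer.step()`. A minimal sketch of that ordering:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(32, 32)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

loss = model(torch.randn(4, 32, device = accelerator.device)).pow(2).mean()
accelerator.backward(loss)
accelerator.clip_grad_norm_(model.parameters(), max_norm = 0.5)  # unscales (if AMP is on) and then clips
optimizer.step()
optimizer.zero_grad()
```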
@@ -718,15 +707,17 @@ class DecoderTrainer(nn.Module):
@cast_torch_tensor
@decoder_sample_in_chunks
def sample(self, *args, **kwargs):
+ distributed = self.accelerator.num_processes > 1
+ base_decoder = self.accelerator.unwrap_model(self.decoder)
if kwargs.pop('use_non_ema', False) or not self.use_ema:
- return self.decoder.sample(*args, **kwargs)
+ return base_decoder.sample(*args, **kwargs, distributed = distributed)
- trainable_unets = self.decoder.unets
- self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
+ trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
+ base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
- output = self.decoder.sample(*args, **kwargs)
+ output = base_decoder.sample(*args, **kwargs, distributed = distributed)
- self.decoder.unets = trainable_unets # restore original training unets
+ base_decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:
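Sampling swaps the EMA unets into the unwrapped decoder and passes along a `distributed` flag derived from the process count. A condensed paraphrase of the flow above; the `try/finally` is my addition to make the restore step explicit:

```python
# `trainer` is assumed to expose .accelerator, .decoder (prepared) and .unets (the EMA copies),
# mirroring the DecoderTrainer attributes used in the diff.
def sample_with_ema(trainer, *args, **kwargs):
    distributed = trainer.accelerator.num_processes > 1
    base_decoder = trainer.accelerator.unwrap_model(trainer.decoder)

    trainable_unets = base_decoder.unets
    base_decoder.unets = trainer.unets                  # swap in the EMA unets for sampling
    try:
        return base_decoder.sample(*args, distributed = distributed, **kwargs)
    finally:
        base_decoder.unets = trainable_unets            # always restore the training unets
```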
@@ -748,13 +739,14 @@ class DecoderTrainer(nn.Module):
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
- with autocast(enabled = self.amp):
+ # with autocast(enabled = self.amp):
+ with self.accelerator.autocast():
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
if self.training:
- self.scale(loss, unet_number = unet_number).backward()
+ self.accelerator.backward(loss)
return total_loss
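The forward pass keeps its micro-batching over `max_batch_size` chunks but replaces the manual GradScaler with `accelerator.autocast()` and `accelerator.backward(loss)`. A self-contained sketch of the chunked pattern with a toy model; the repo's `split_args_and_kwargs` helper is stood in for by a plain `tensor.split`:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

batch = torch.randn(64, 16, device = accelerator.device)
max_batch_size, total_loss = 16, 0.
for chunk in batch.split(max_batch_size):
    frac = chunk.shape[0] / batch.shape[0]              # weight each chunk by its share of the batch
    with accelerator.autocast():                        # replaces autocast(enabled = self.amp)
        loss = model(chunk).pow(2).mean() * frac
    total_loss += loss.item()
    accelerator.backward(loss)                          # replaces self.scale(loss).backward()
optimizer.step()
optimizer.zero_grad()
```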

@@ -1,4 +1,5 @@
import time
+ import importlib
# time helpers
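The added `importlib` import feeds the optional-dependency helper the trackers already call (`import_or_print_error` above). A plausible shape for such a helper, which may differ from the repo's exact implementation:

```python
import importlib

def import_or_print_error(pkg_name, err_str = None):
    # Import an optional dependency by module name; print a hint and bail out if it is missing.
    try:
        return importlib.import_module(pkg_name)
    except ModuleNotFoundError:
        if err_str is not None:
            print(err_str)
        raise SystemExit(1)
```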