Prior updates (#211)

* update configs for prior

add prior warmup to config

update example prior config

* update prior trainer & script

add deepspeed amp & warmup

adopt full accelerator support

reload at sample point

finish epoch resume code

* update tracker save method for prior

* helper functions for prior_loader
This commit is contained in:
zion
2022-07-20 18:04:26 -07:00
committed by GitHub
parent 06c65b60d2
commit f9423d308b
6 changed files with 676 additions and 352 deletions

View File

@@ -174,27 +174,21 @@ class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
diffusion_prior,
accelerator,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
eps = 1e-6,
max_grad_norm = None,
amp = False,
group_wd_params = True,
device = None,
accelerator = None,
verbose = True,
warmup_steps = 1,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
assert isinstance(accelerator, Accelerator)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
# verbosity
self.verbose = verbose
# assign some helpful member vars
self.accelerator = accelerator
@@ -202,23 +196,31 @@ class DiffusionPriorTrainer(nn.Module):
# setting the device
if not exists(accelerator) and not exists(device):
diffusion_prior_device = next(diffusion_prior.parameters()).device
self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
self.device = diffusion_prior_device
else:
self.device = accelerator.device if exists(accelerator) else device
diffusion_prior.to(self.device)
self.device = accelerator.device
diffusion_prior.to(self.device)
# save model
self.diffusion_prior = diffusion_prior
# optimizer and mixed precision stuff
# mixed precision checks
self.amp = amp
if (
exists(self.accelerator)
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
and self.diffusion_prior.clip is not None
):
# Then we need to make sure clip is using the correct precision or else deepspeed will error
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
self.diffusion_prior.clip.to(precision_type)
self.scaler = GradScaler(enabled = amp)
# optimizer stuff
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
@@ -227,17 +229,21 @@ class DiffusionPriorTrainer(nn.Module):
**self.optim_kwargs,
**kwargs
)
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
# distribute the model if using HFA
if exists(self.accelerator):
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)
# exponential moving average stuff
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)
# gradient clipping if needed
@@ -247,67 +253,24 @@ class DiffusionPriorTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0], device = self.device))
# accelerator wrappers
def print(self, msg):
    """Emit ``msg`` when verbose output is enabled.

    Does nothing if ``self.verbose`` is falsy. Otherwise the message is
    passed to ``self.accelerator.print`` when an accelerator is set, and
    to the plain ``print`` function when it is not.
    """
    if self.verbose:
        # the bare name `print` here is looked up in module/builtin scope,
        # not on the class, so it is not this method
        emit = self.accelerator.print if exists(self.accelerator) else print
        emit(msg)
def unwrap_model(self, model):
    """Return ``model`` unwrapped by the accelerator, if there is one.

    Without an accelerator the model is handed back untouched.
    """
    if not exists(self.accelerator):
        return model
    return self.accelerator.unwrap_model(model)
def wait_for_everyone(self):
    """Call ``accelerator.wait_for_everyone()`` when an accelerator is set.

    A no-op when the trainer has no accelerator.
    """
    if not exists(self.accelerator):
        return
    self.accelerator.wait_for_everyone()
def is_main_process(self):
    """Report whether this process should act as the main one.

    Returns ``accelerator.is_main_process`` when an accelerator is set;
    without an accelerator the answer is always ``True``.
    """
    return self.accelerator.is_main_process if exists(self.accelerator) else True
def clip_grad_norm_(self, *args):
    """Clip gradient norms, forwarding ``*args`` unchanged.

    Delegates to ``accelerator.clip_grad_norm_`` when an accelerator is
    set, and to ``torch.nn.utils.clip_grad_norm_`` otherwise. The result
    of whichever function was called is returned.
    """
    if exists(self.accelerator):
        clip_fn = self.accelerator.clip_grad_norm_
    else:
        clip_fn = torch.nn.utils.clip_grad_norm_
    return clip_fn(*args)
def backprop(self, x):
    """Run the backward pass for ``x``.

    With an accelerator, ``accelerator.backward(x)`` is used. Without one,
    ``x.backward()`` is called directly and any ``Exception`` it raises is
    reported through ``self.print`` instead of propagating.
    """
    if exists(self.accelerator):
        self.accelerator.backward(x)
        return

    # NOTE(review): this catch-all means a failed backward pass does not stop
    # the caller, and with verbose output off nothing is printed at all --
    # confirm this best-effort behaviour is intended.
    try:
        x.backward()
    except Exception as e:
        self.print(f"Caught error in backprop call: {e}")
# utility
def save(self, path, overwrite = True, **kwargs):
# ensure we sync gradients before continuing
self.wait_for_everyone()
# only save on the main process
if self.is_main_process():
self.print(f"Saving checkpoint at step: {self.step.item()}")
if self.accelerator.is_main_process:
print(f"Saving checkpoint at step: {self.step.item()}")
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
# FIXME: LambdaLR can't be saved due to pickling issues
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
warmup_scheduler = self.warmup_scheduler,
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
version = version.parse(__version__),
step = self.step.item(),
step = self.step,
**kwargs
)
@@ -320,14 +283,14 @@ class DiffusionPriorTrainer(nn.Module):
torch.save(save_obj, str(path))
def load(self, path, overwrite_lr = True, strict = True):
def load(self, path_or_state, overwrite_lr = True, strict = True):
"""
Load a checkpoint of a diffusion prior trainer.
Will load the entire trainer, including the optimizer and EMA.
Params:
- path (str): a path to the DiffusionPriorTrainer checkpoint file
- path_or_state (str | dict): a path to the DiffusionPriorTrainer checkpoint file, or an already-loaded checkpoint dict
- overwrite_lr (bool): whether or not to overwrite the stored LR with the LR specified in the new trainer
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
@@ -336,56 +299,56 @@ class DiffusionPriorTrainer(nn.Module):
"""
# all processes need to load checkpoint. no restriction here
path = Path(path)
assert path.exists()
if isinstance(path_or_state, str):
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path), map_location=self.device)
loaded_obj = torch.load(str(path), map_location=self.device)
elif isinstance(path_or_state, dict):
loaded_obj = path_or_state
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
# unwrap the model when loading from checkpoint
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
self.scaler.load_state_dict(loaded_obj['scaler'])
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
self.optimizer.load_state_dict(loaded_obj['optimizer'])
# set warmupstep
if exists(self.warmup_scheduler):
self.warmup_scheduler.last_step = self.step.item()
# ensure new lr is used if different from old one
if overwrite_lr:
new_lr = self.optim_kwargs["lr"]
self.print(f"Overriding LR to be {new_lr}")
for group in self.optimizer.param_groups:
group["lr"] = new_lr
group["lr"] = new_lr if group["lr"] > 0.0 else 0.0
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
# below not be necessary, but I had a suspicion that this wasn't being loaded correctly
# below might not be necessary, but I had a suspicion that this wasn't being loaded correctly
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
# sync and inform
self.wait_for_everyone()
self.print(f"Loaded model")
return loaded_obj
# model functionality
def update(self):
# only continue with updates until all ranks finish
self.wait_for_everyone()
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
# utilize HFA clipping where applicable
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad()
# accelerator will occasionally skip optimizer steps in a "dynamic loss scaling strategy"
if not self.accelerator.optimizer_step_was_skipped:
with self.warmup_scheduler.dampening():
self.scheduler.step()
if self.use_ema:
self.ema_diffusion_prior.update()
@@ -414,7 +377,7 @@ class DiffusionPriorTrainer(nn.Module):
@cast_torch_tensor
@prior_sample_in_chunks
def embed_text(self, *args, **kwargs):
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
@cast_torch_tensor
def forward(
@@ -426,16 +389,14 @@ class DiffusionPriorTrainer(nn.Module):
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
with self.accelerator.autocast():
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
# backprop with accelerate if applicable
if self.training:
self.backprop(self.scaler.scale(loss))
self.accelerator.backward(loss)
return total_loss