Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
Distributed Training of the Decoder (#121)
* Converted decoder trainer to use accelerate
* Fixed issue where metric evaluation would hang in distributed mode
* Implemented functional saving. Loading still fails due to some issue with the optimizer
* Fixed issue with loading decoders
* Fixed issue with tracker config
* Fixed issue with amp. Updated logging to be more logical
* Saving checkpoint now saves position in training as well. Fixed an issue with running out of GPU space due to loading weights into the GPU twice
* Fixed ema for distributed training
* Fixed issue where get_pkg_version was reintroduced
* Changed decoder trainer to upload config as a file. Fixed issue where loading best would error
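As orientation for the diff below, here is a minimal sketch of how the reworked trainer is meant to be constructed once it takes an accelerate Accelerator as its first argument. The Unet/Decoder hyperparameters, the OpenAIClipAdapter choice, and all numeric values are illustrative assumptions, not taken from this commit:

# a minimal sketch, assuming the post-commit DecoderTrainer signature;
# the clip / unet / decoder settings below are placeholders
from accelerate import Accelerator
from dalle2_pytorch import Unet, Decoder, DecoderTrainer, OpenAIClipAdapter

accelerator = Accelerator()   # owns device placement, DDP wrapping and mixed precision

clip = OpenAIClipAdapter()

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

decoder = Decoder(
    unet = unet,
    clip = clip,
    image_sizes = (64,),
    timesteps = 100
)

decoder_trainer = DecoderTrainer(
    accelerator,   # new first argument: the trainer calls accelerator.prepare(decoder, *optimizers) internally
    decoder,
    lr = 3e-4,
    wd = 1e-2,
    use_ema = True,
    ema_beta = 0.99
)

No explicit .cuda() call or GradScaler is needed with this setup; device placement and gradient scaling are handled by the prepared objects.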
@@ -574,6 +574,7 @@ def decoder_sample_in_chunks(fn):
 class DecoderTrainer(nn.Module):
     def __init__(
         self,
+        accelerator,
         decoder,
         use_ema = True,
         lr = 1e-4,
@@ -588,8 +589,9 @@ class DecoderTrainer(nn.Module):
         assert isinstance(decoder, Decoder)
         ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
 
-        self.decoder = decoder
-        self.num_unets = len(self.decoder.unets)
+        self.accelerator = accelerator
+
+        self.num_unets = len(decoder.unets)
 
         self.use_ema = use_ema
         self.ema_unets = nn.ModuleList([])
@@ -601,7 +603,9 @@ class DecoderTrainer(nn.Module):
 
         lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
 
-        for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
+        optimizers = []
+
+        for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
             optimizer = get_optimizer(
                 unet.parameters(),
                 lr = unet_lr,
@@ -611,19 +615,20 @@ class DecoderTrainer(nn.Module):
                 **kwargs
             )
 
-            setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
+            optimizers.append(optimizer)
 
             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))
 
-            scaler = GradScaler(enabled = amp)
-            setattr(self, f'scaler{ind}', scaler)
-
         # gradient clipping if needed
 
         self.max_grad_norm = max_grad_norm
 
         self.register_buffer('step', torch.tensor([0.]))
+        results = list(self.accelerator.prepare(decoder, *optimizers))
+        self.decoder = results.pop(0)
+        for opt_ind in range(len(optimizers)):
+            setattr(self, f'optim{opt_ind}', results.pop(0))
 
     def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
@@ -631,47 +636,42 @@ class DecoderTrainer(nn.Module):
         path.parent.mkdir(parents = True, exist_ok = True)
 
         save_obj = dict(
-            model = self.decoder.state_dict(),
+            model = self.accelerator.unwrap_model(self.decoder).state_dict(),
             version = __version__,
             step = self.step.item(),
             **kwargs
         )
 
         for ind in range(0, self.num_unets):
-            scaler_key = f'scaler{ind}'
-            optimizer_key = f'optim{ind}'
-            scaler = getattr(self, scaler_key)
+            optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
-            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
+            save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
 
         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
 
-        torch.save(save_obj, str(path))
+        self.accelerator.save(save_obj, str(path))
 
     def load(self, path, only_model = False, strict = True):
         path = Path(path)
         assert path.exists()
 
-        loaded_obj = torch.load(str(path))
+        loaded_obj = torch.load(str(path), map_location = 'cpu')
 
-        if version.parse(__version__) != loaded_obj['version']:
-            print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
+        if version.parse(__version__) != version.parse(loaded_obj['version']):
+            self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
 
-        self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
+        self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
 
         if only_model:
             return loaded_obj
 
         for ind in range(0, self.num_unets):
-            scaler_key = f'scaler{ind}'
-            optimizer_key = f'optim{ind}'
-            scaler = getattr(self, scaler_key)
+            optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
 
-            scaler.load_state_dict(loaded_obj[scaler_key])
-            optimizer.load_state_dict(loaded_obj[optimizer_key])
+            self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
 
         if self.use_ema:
             assert 'ema' in loaded_obj
@@ -683,29 +683,18 @@ class DecoderTrainer(nn.Module):
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
 
-    def scale(self, loss, *, unet_number):
-        assert 1 <= unet_number <= self.num_unets
-        index = unet_number - 1
-        scaler = getattr(self, f'scaler{index}')
-        return scaler.scale(loss)
-
     def update(self, unet_number = None):
         if self.num_unets == 1:
             unet_number = default(unet_number, 1)
 
         assert exists(unet_number) and 1 <= unet_number <= self.num_unets
         index = unet_number - 1
         unet = self.decoder.unets[index]
 
         optimizer = getattr(self, f'optim{index}')
-        scaler = getattr(self, f'scaler{index}')
-
         if exists(self.max_grad_norm):
-            scaler.unscale_(optimizer)
-            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
-
-        scaler.step(optimizer)
-        scaler.update()
+            self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
+        optimizer.step()
         optimizer.zero_grad()
 
         if self.use_ema:
@@ -718,15 +707,17 @@ class DecoderTrainer(nn.Module):
     @cast_torch_tensor
     @decoder_sample_in_chunks
     def sample(self, *args, **kwargs):
+        distributed = self.accelerator.num_processes > 1
+        base_decoder = self.accelerator.unwrap_model(self.decoder)
         if kwargs.pop('use_non_ema', False) or not self.use_ema:
-            return self.decoder.sample(*args, **kwargs)
+            return base_decoder.sample(*args, **kwargs, distributed = distributed)
 
-        trainable_unets = self.decoder.unets
-        self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
+        trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
+        base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
 
-        output = self.decoder.sample(*args, **kwargs)
+        output = base_decoder.sample(*args, **kwargs, distributed = distributed)
 
-        self.decoder.unets = trainable_unets # restore original training unets
+        base_decoder.unets = trainable_unets # restore original training unets
 
         # cast the ema_model unets back to original device
         for ema in self.ema_unets:
@@ -748,13 +739,14 @@ class DecoderTrainer(nn.Module):
         total_loss = 0.
 
         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
-            with autocast(enabled = self.amp):
+            # with autocast(enabled = self.amp):
+            with self.accelerator.autocast():
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
                 loss = loss * chunk_size_frac
 
             total_loss += loss.item()
 
             if self.training:
-                self.scale(loss, unet_number = unet_number).backward()
+                self.accelerator.backward(loss)
 
         return total_loss
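Continuing the construction sketch above, per-step usage stays the same on the surface while accelerate now handles autocast, backward, gradient clipping and optimizer stepping underneath. The dataloader, batch variable names and checkpoint path below are hypothetical:

# a minimal sketch of one training step with the accelerate-backed trainer;
# `dataloader`, `images`, `image_embed` and the checkpoint path are assumptions
for step, (images, image_embed) in enumerate(dataloader):
    # forward() splits the batch by max_batch_size, runs each chunk under
    # accelerator.autocast(), and calls accelerator.backward(loss) while training
    loss = decoder_trainer(
        images,
        image_embed = image_embed,
        unet_number = 1,
        max_batch_size = 4
    )

    # update() steps the prepared optimizer, applies accelerator.clip_grad_norm_
    # when max_grad_norm is set, and advances the EMA copy of the unet
    decoder_trainer.update(unet_number = 1)

    if step % 1000 == 0 and accelerator.is_main_process:
        # save() stores the unwrapped decoder state via accelerator.save
        decoder_trainer.save('./decoder.pt')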