small cleanup

This commit is contained in:
Phil Wang
2022-06-20 08:59:51 -07:00
parent 138079ca83
commit 67f0740777

View File

@@ -627,10 +627,13 @@ class DecoderTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
results = list(self.accelerator.prepare(decoder, *optimizers))
self.decoder = results.pop(0)
for opt_ind in range(len(optimizers)):
setattr(self, f'optim{opt_ind}', results.pop(0))
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
self.decoder = decoder
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
setattr(self, f'optim{opt_ind}', optimizer)
def save(self, path, overwrite = True, **kwargs):
path = Path(path)