small cleanup

This commit is contained in:
Phil Wang
2022-06-20 08:59:51 -07:00
parent 138079ca83
commit 67f0740777

View File

@@ -627,10 +627,13 @@ class DecoderTrainer(nn.Module):
self.max_grad_norm = max_grad_norm self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.])) self.register_buffer('step', torch.tensor([0.]))
results = list(self.accelerator.prepare(decoder, *optimizers))
self.decoder = results.pop(0) decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
for opt_ind in range(len(optimizers)):
setattr(self, f'optim{opt_ind}', results.pop(0)) self.decoder = decoder
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
setattr(self, f'optim{opt_ind}', optimizer)
def save(self, path, overwrite = True, **kwargs): def save(self, path, overwrite = True, **kwargs):
path = Path(path) path = Path(path)