From 67f07407773952f411c770870ca95f150368756a Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 20 Jun 2022 08:59:51 -0700 Subject: [PATCH] small cleanup --- dalle2_pytorch/trainer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index bc68caa..83aa3d8 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -627,10 +627,13 @@ class DecoderTrainer(nn.Module): self.max_grad_norm = max_grad_norm self.register_buffer('step', torch.tensor([0.])) - results = list(self.accelerator.prepare(decoder, *optimizers)) - self.decoder = results.pop(0) - for opt_ind in range(len(optimizers)): - setattr(self, f'optim{opt_ind}', results.pop(0)) + + decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers)) + + self.decoder = decoder + + for opt_ind, optimizer in zip(range(len(optimizers)), optimizers): + setattr(self, f'optim{opt_ind}', optimizer) def save(self, path, overwrite = True, **kwargs): path = Path(path)