diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index bc68caa..83aa3d8 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -627,10 +627,13 @@ class DecoderTrainer(nn.Module): self.max_grad_norm = max_grad_norm self.register_buffer('step', torch.tensor([0.])) - results = list(self.accelerator.prepare(decoder, *optimizers)) - self.decoder = results.pop(0) - for opt_ind in range(len(optimizers)): - setattr(self, f'optim{opt_ind}', results.pop(0)) + + decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers)) + + self.decoder = decoder + + for opt_ind, optimizer in zip(range(len(optimizers)), optimizers): + setattr(self, f'optim{opt_ind}', optimizer) def save(self, path, overwrite = True, **kwargs): path = Path(path)