mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
small cleanup
This commit is contained in:
@@ -627,10 +627,13 @@ class DecoderTrainer(nn.Module):
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
results = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
self.decoder = results.pop(0)
|
||||
for opt_ind in range(len(optimizers)):
|
||||
setattr(self, f'optim{opt_ind}', results.pop(0))
|
||||
|
||||
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
|
||||
setattr(self, f'optim{opt_ind}', optimizer)
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
path = Path(path)
|
||||
|
||||
Reference in New Issue
Block a user