mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
small cleanup
This commit is contained in:
@@ -627,10 +627,13 @@ class DecoderTrainer(nn.Module):
|
|||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
self.register_buffer('step', torch.tensor([0.]))
|
self.register_buffer('step', torch.tensor([0.]))
|
||||||
results = list(self.accelerator.prepare(decoder, *optimizers))
|
|
||||||
self.decoder = results.pop(0)
|
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||||
for opt_ind in range(len(optimizers)):
|
|
||||||
setattr(self, f'optim{opt_ind}', results.pop(0))
|
self.decoder = decoder
|
||||||
|
|
||||||
|
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
|
||||||
|
setattr(self, f'optim{opt_ind}', optimizer)
|
||||||
|
|
||||||
def save(self, path, overwrite = True, **kwargs):
|
def save(self, path, overwrite = True, **kwargs):
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
|
|||||||
Reference in New Issue
Block a user