diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index 6467eba..4afb6a4 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -6,6 +6,7 @@ from functools import partial, wraps from collections.abc import Iterable import torch +import torch.nn.functional as F from torch import nn from torch.cuda.amp import autocast, GradScaler @@ -474,7 +475,7 @@ class DecoderTrainer(nn.Module): self.max_grad_norm = max_grad_norm - self.register_buffer('step', torch.tensor([0.])) + self.register_buffer('steps', torch.tensor([0] * self.num_unets)) decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers)) @@ -491,7 +492,7 @@ class DecoderTrainer(nn.Module): save_obj = dict( model = self.accelerator.unwrap_model(self.decoder).state_dict(), version = __version__, - step = self.step.item(), + steps = self.steps.cpu(), **kwargs ) @@ -510,7 +511,7 @@ class DecoderTrainer(nn.Module): self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}') self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict) - self.step.copy_(torch.ones_like(self.step) * loaded_obj['step']) + self.steps.copy_(loaded_obj['steps']) if only_model: return loaded_obj @@ -539,6 +540,12 @@ class DecoderTrainer(nn.Module): def unets(self): return nn.ModuleList([ema.ema_model for ema in self.ema_unets]) + def increment_step(self, unet_number): + assert 1 <= unet_number <= self.num_unets + + unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device) + self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps)) + def update(self, unet_number = None): if self.num_unets == 1: unet_number = default(unet_number, 1) @@ -557,7 +564,7 @@ class DecoderTrainer(nn.Module): ema_unet = self.ema_unets[index] ema_unet.update() - self.step += 1 + self.increment_step(unet_number) @torch.no_grad() @cast_torch_tensor diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index e935064..66e314a 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.2' +__version__ = '0.16.3'