Track training steps separately for each unet in the decoder trainer

This commit is contained in:
Phil Wang
2022-07-05 15:17:59 -07:00
parent b9a908ff75
commit 3afdcdfe86
2 changed files with 12 additions and 5 deletions

View File

@@ -6,6 +6,7 @@ from functools import partial, wraps
from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import autocast, GradScaler
@@ -474,7 +475,7 @@ class DecoderTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
@@ -491,7 +492,7 @@ class DecoderTrainer(nn.Module):
save_obj = dict(
model = self.accelerator.unwrap_model(self.decoder).state_dict(),
version = __version__,
step = self.step.item(),
steps = self.steps.cpu(),
**kwargs
)
@@ -510,7 +511,7 @@ class DecoderTrainer(nn.Module):
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
self.steps.copy_(loaded_obj['steps'])
if only_model:
return loaded_obj
@@ -539,6 +540,12 @@ class DecoderTrainer(nn.Module):
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def increment_step(self, unet_number):
    """Advance the training-step counter of a single unet by one.

    Args:
        unet_number: 1-based number of the unet that was just updated;
            must lie in ``[1, self.num_unets]``.

    Raises:
        AssertionError: if ``unet_number`` is outside the valid range.
    """
    assert 1 <= unet_number <= self.num_unets, f'unet number must be between 1 and {self.num_unets}'

    # `self.steps` holds one counter per unet. Bump the matching slot in
    # place (unet numbers are 1-based, tensor indices 0-based) instead of
    # building a throwaway one-hot vector on every training step.
    self.steps[unet_number - 1] += 1
def update(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
@@ -557,7 +564,7 @@ class DecoderTrainer(nn.Module):
ema_unet = self.ema_unets[index]
ema_unet.update()
self.step += 1
self.increment_step(unet_number)
@torch.no_grad()
@cast_torch_tensor

View File

@@ -1 +1 @@
__version__ = '0.16.2'
__version__ = '0.16.3'