mirror of https://github.com/lucidrains/DALLE2-pytorch.git
need to keep track of training steps separately for each unet in decoder trainer
@@ -6,6 +6,7 @@ from functools import partial, wraps
 from collections.abc import Iterable
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 from torch.cuda.amp import autocast, GradScaler
 
@@ -474,7 +475,7 @@ class DecoderTrainer(nn.Module):
 
         self.max_grad_norm = max_grad_norm
 
-        self.register_buffer('step', torch.tensor([0.]))
+        self.register_buffer('steps', torch.tensor([0] * self.num_unets))
 
         decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
 
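
For context, the new steps buffer holds one integer counter per unet instead of the single float scalar used before. A minimal standalone sketch of the idea (the StepTracker name and the num_unets default are illustrative, not the library's DecoderTrainer):

import torch
from torch import nn

# Minimal sketch (hypothetical StepTracker, not the library's DecoderTrainer):
# registering the counters as a buffer keeps them in state_dict() and moves
# them together with the module across devices.
class StepTracker(nn.Module):
    def __init__(self, num_unets = 3):
        super().__init__()
        self.register_buffer('steps', torch.tensor([0] * num_unets))

tracker = StepTracker()
print(tracker.steps)        # tensor([0, 0, 0])
print(tracker.state_dict()) # OrderedDict([('steps', tensor([0, 0, 0]))])
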
@@ -491,7 +492,7 @@ class DecoderTrainer(nn.Module):
         save_obj = dict(
             model = self.accelerator.unwrap_model(self.decoder).state_dict(),
             version = __version__,
-            step = self.step.item(),
+            steps = self.steps.cpu(),
             **kwargs
         )
 
@@ -510,7 +511,7 @@ class DecoderTrainer(nn.Module):
             self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
 
         self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
-        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
+        self.steps.copy_(loaded_obj['steps'])
 
         if only_model:
             return loaded_obj
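
The save and load hunks above keep the per-unet counters in the checkpoint as a tensor (steps = self.steps.cpu()) and restore them in place with copy_. A rough sketch of that round trip, reusing the hypothetical StepTracker module from the earlier sketch and an illustrative checkpoint path (neither is the trainer's actual checkpoint format):

import torch
from torch import nn

class StepTracker(nn.Module):  # same hypothetical module as in the earlier sketch
    def __init__(self, num_unets = 3):
        super().__init__()
        self.register_buffer('steps', torch.tensor([0] * num_unets))

tracker = StepTracker()
tracker.steps += torch.tensor([5, 2, 9])        # pretend training advanced the counters

# saving: keep the counters on CPU inside the checkpoint dict
save_obj = dict(steps = tracker.steps.cpu())
torch.save(save_obj, '/tmp/decoder_steps.pt')   # illustrative path

# loading: copy the saved counters back into the registered buffer
restored = StepTracker()
loaded_obj = torch.load('/tmp/decoder_steps.pt')
restored.steps.copy_(loaded_obj['steps'])
print(restored.steps)                           # tensor([5, 2, 9])
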
@@ -539,6 +540,12 @@ class DecoderTrainer(nn.Module):
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
 
+    def increment_step(self, unet_number):
+        assert 1 <= unet_number <= self.num_unets
+
+        unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
+        self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
+
     def update(self, unet_number = None):
         if self.num_unets == 1:
             unet_number = default(unet_number, 1)
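
The one-hot addition in increment_step bumps only the counter belonging to the unet that was just stepped. A standalone illustration of that arithmetic with plain tensors (the example values are made up, and the tensor here stands in for the trainer's steps buffer):

import torch
import torch.nn.functional as F

steps = torch.tensor([4, 7, 0])   # per-unet counters, here for 3 unets
unet_number = 2                   # 1-indexed, matching the trainer's convention

# build a one-hot vector selecting unet 2 and add it to the counters
unet_index_tensor = torch.tensor(unet_number - 1, device = steps.device)
steps += F.one_hot(unet_index_tensor, num_classes = len(steps))

print(steps)                      # tensor([4, 8, 0])
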
@@ -557,7 +564,7 @@ class DecoderTrainer(nn.Module):
             ema_unet = self.ema_unets[index]
             ema_unet.update()
 
-        self.step += 1
+        self.increment_step(unet_number)
 
     @torch.no_grad()
     @cast_torch_tensor
@@ -1 +1 @@
-__version__ = '0.16.2'
+__version__ = '0.16.3'