mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
expose num_steps_taken helper method on trainer to retrieve number of training steps of each unet
This commit is contained in:
@@ -527,6 +527,17 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.warmup_schedulers = warmup_schedulers
|
||||
|
||||
def validate_and_return_unet_number(self, unet_number = None):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
|
||||
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
|
||||
return unet_number
|
||||
|
||||
def num_steps_taken(self, unet_number = None):
|
||||
unet_number = self.validate_and_return_unet_number(unet_number)
|
||||
return self.steps[unet_number - 1].item()
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
@@ -595,10 +606,7 @@ class DecoderTrainer(nn.Module):
|
||||
self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
|
||||
|
||||
def update(self, unet_number = None):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
|
||||
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
|
||||
unet_number = self.validate_and_return_unet_number(unet_number)
|
||||
index = unet_number - 1
|
||||
|
||||
optimizer = getattr(self, f'optim{index}')
|
||||
@@ -664,8 +672,7 @@ class DecoderTrainer(nn.Module):
|
||||
max_batch_size = None,
|
||||
**kwargs
|
||||
):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
unet_number = self.validate_and_return_unet_number(unet_number)
|
||||
|
||||
total_loss = 0.
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.16.18'
|
||||
__version__ = '0.16.19'
|
||||
|
||||
Reference in New Issue
Block a user