use a decorator and take care of sampling in chunks (max_batch_size keyword), in case one is sampling a huge grid of images

This commit is contained in:
Phil Wang
2022-05-16 12:34:28 -07:00
parent 5562ec6be2
commit 164d9be444
2 changed files with 30 additions and 7 deletions

View File

@@ -47,6 +47,14 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
# decorators
def cast_torch_tensor(fn):
@@ -322,6 +330,22 @@ class DiffusionPriorTrainer(nn.Module):
# decoder trainer
def decoder_sample_in_chunks(fn):
@wraps(fn)
def inner(self, *args, max_batch_size = None, **kwargs):
if not exists(max_batch_size):
return fn(self, *args, **kwargs)
if self.decoder.unconditional:
batch_size = kwargs.get('batch_size')
batch_sizes = num_to_groups(batch_size, max_batch_size)
outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
else:
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
return torch.cat(outputs, dim = 0)
return inner
class DecoderTrainer(nn.Module):
def __init__(
self,
@@ -411,18 +435,17 @@ class DecoderTrainer(nn.Module):
@torch.no_grad()
@cast_torch_tensor
@decoder_sample_in_chunks
def sample(self, *args, **kwargs):
if kwargs.pop('use_non_ema', False):
if kwargs.pop('use_non_ema', False) or not self.use_ema:
return self.decoder.sample(*args, **kwargs)
if self.use_ema:
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = self.decoder.sample(*args, **kwargs)
if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets
self.decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets: