final update to the dalle2 repository for a while - sampling from the prior now happens in chunks automatically when the max_batch_size keyword is given

This commit is contained in:
Phil Wang
2022-05-16 12:57:31 -07:00
parent c3d4a7ffe4
commit 13382885d9
3 changed files with 16 additions and 4 deletions

README.md

@@ -820,8 +820,8 @@ clip = CLIP(
 # mock data

-text = torch.randint(0, 49408, (32, 256)).cuda()
-images = torch.randn(32, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (512, 256)).cuda()
+images = torch.randn(512, 3, 256, 256).cuda()

 # prior networks (with transformer)

@@ -854,7 +854,7 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
 # after much of the above three lines in a loop
 # you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior

-image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
+image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
 ```

 ## Bonus

dalle2_pytorch/trainer.py

@@ -235,6 +235,16 @@ class EMA(nn.Module):
 # diffusion prior trainer

+def prior_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+        return torch.cat(outputs, dim = 0)
+    return inner
+
 class DiffusionPriorTrainer(nn.Module):
     def __init__(
         self,

@@ -295,11 +305,13 @@ class DiffusionPriorTrainer(nn.Module):
     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)

     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
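The decorator above leans on `split_args_and_kwargs`, a helper that already exists elsewhere in `dalle2_pytorch` and is not part of this diff. All the decorator assumes is its yield shape: `(chunk_fraction, (chunked_args, chunked_kwargs))` tuples. A simplified, hypothetical sketch of that contract (positional tensor args only; the real helper is more general):

```python
import torch

def split_args_and_kwargs(*args, split_size = None, **kwargs):
    # hypothetical simplified version: assume every positional argument is a
    # tensor sharing the same batch dimension, and pass kwargs through as-is
    batch_size = args[0].shape[0]
    for start in range(0, batch_size, split_size):
        chunked_args = tuple(t[start:start + split_size] for t in args)
        chunk_fraction = chunked_args[0].shape[0] / batch_size
        # matches the `_, (chunked_args, chunked_kwargs)` unpacking
        # in prior_sample_in_chunks above
        yield chunk_fraction, (chunked_args, kwargs)

# usage sketch: 10 rows split into chunks of 4, 4, and 2
texts = torch.randint(0, 49408, (10, 256))
for frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(texts, split_size = 4):
    print(frac, chunked_args[0].shape)
```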

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.46',
+  version = '0.3.0',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',