From 13382885d94c45cb6b4fee0c788d3d6f99a1086d Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 16 May 2022 12:57:31 -0700
Subject: [PATCH] final update to dalle2 repository for a while - sampling
 from prior in chunks automatically with max_batch_size keyword given

---
 README.md                 |  6 +++---
 dalle2_pytorch/trainer.py | 12 ++++++++++++
 setup.py                  |  2 +-
 3 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 69c94ff..5261be1 100644
--- a/README.md
+++ b/README.md
@@ -820,8 +820,8 @@ clip = CLIP(
 
 # mock data
 
-text = torch.randint(0, 49408, (32, 256)).cuda()
-images = torch.randn(32, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (512, 256)).cuda()
+images = torch.randn(512, 3, 256, 256).cuda()
 
 # prior networks (with transformer)
 
@@ -854,7 +854,7 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
 # after much of the above three lines in a loop
 # you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
 
-image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
+image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
 ```
 
 ## Bonus
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 9d96654..090608e 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -235,6 +235,16 @@ class EMA(nn.Module):
 
 # diffusion prior trainer
 
+def prior_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+        return torch.cat(outputs, dim = 0)
+    return inner
+
 class DiffusionPriorTrainer(nn.Module):
     def __init__(
         self,
@@ -295,11 +305,13 @@ class DiffusionPriorTrainer(nn.Module):
 
     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
 
     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
 
diff --git a/setup.py b/setup.py
index f2ac23f..0ccf36e 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.46',
+  version = '0.3.0',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
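
Note on the chunked sampling added above: `prior_sample_in_chunks` defers to a `split_args_and_kwargs` helper that already exists in `dalle2_pytorch/trainer.py` and is not shown in this patch. The sketch below reproduces the chunking pattern in isolation so the decorator can be read on its own. `split_args_and_kwargs_sketch`, its per-chunk yield signature, and the stand-in `fn` are illustrative assumptions, not the repository's actual implementation.

```python
import torch

# Hypothetical stand-in for the repository's split_args_and_kwargs helper:
# slices every tensor argument along dim 0 into chunks of at most split_size,
# yielding (fraction_of_batch, (chunked_args, chunked_kwargs)) per chunk.
def split_args_and_kwargs_sketch(*args, split_size = None, **kwargs):
    batch_size = args[0].shape[0]
    for start in range(0, batch_size, split_size):
        sl = slice(start, start + split_size)
        chunked_args = tuple(a[sl] if torch.is_tensor(a) else a for a in args)
        chunked_kwargs = {k: (v[sl] if torch.is_tensor(v) else v) for k, v in kwargs.items()}
        yield chunked_args[0].shape[0] / batch_size, (chunked_args, chunked_kwargs)

# usage: a 512-row batch processed in chunks of 4 and re-concatenated,
# mirroring what prior_sample_in_chunks does around DiffusionPrior.sample
text = torch.randint(0, 49408, (512, 256))
fn = lambda t: t.float().mean(dim = -1, keepdim = True)  # stand-in for the wrapped sample call
outputs = [fn(*chunked_args, **chunked_kwargs)
           for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs_sketch(text, split_size = 4)]
result = torch.cat(outputs, dim = 0)  # (512, 1); sample() would instead return (512, 512) embeddings
```

The point of the decorator is that `max_batch_size` is purely an inference-time memory cap: callers pass the full batch, each chunk runs through the EMA prior independently, and `torch.cat` restores the original batch dimension, which is why the README example can feed 512 texts while sampling only 4 at a time.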