From 8260fc933a9a5118e18f209314ecfad246a42454 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 30 Apr 2022 15:10:25 -0700 Subject: [PATCH] allows one to shortcut sampling at a specific unet number, if one were to be training in stages --- dalle2_pytorch/dalle2_pytorch.py | 13 +++++++++++-- setup.py | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index efdaf99..89ab4f4 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1540,7 +1540,13 @@ class Decoder(BaseGaussianDiffusion): @torch.no_grad() @eval_decorator - def sample(self, image_embed, text = None, cond_scale = 1.): + def sample( + self, + image_embed, + text = None, + cond_scale = 1., + stop_at_unet_number = None + ): batch_size = image_embed.shape[0] text_encodings = text_mask = None @@ -1552,7 +1558,7 @@ class Decoder(BaseGaussianDiffusion): img = None - for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)): + for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)): context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context() @@ -1584,6 +1590,9 @@ class Decoder(BaseGaussianDiffusion): img = vae.decode(img) + if exists(stop_at_unet_number) and stop_at_unet_number == unet_number: + break + return img def forward( diff --git a/setup.py b/setup.py index c78d057..441e8b0 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.81', + version = '0.0.82', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',