From d1a697ac23b867bdacf4dded77725da53ce1afc8 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 30 Apr 2022 16:05:13 -0700 Subject: [PATCH] allows one to shortcut sampling at a specific unet number, if one were to be training in stages --- README.md | 2 +- dalle2_pytorch/dalle2_pytorch.py | 13 +++++++++++-- setup.py | 2 +- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 387837a..354a89e 100644 --- a/README.md +++ b/README.md @@ -783,7 +783,7 @@ for unet_number in (1, 2): # you can sample from the exponentially moving averaged unets as so mock_image_embed = torch.randn(4, 512).cuda() -images = decoder.sample(mock_image_embed, text = text) # (4, 3, 256, 256) +images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256) ``` ## CLI (wip) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index efdaf99..89ab4f4 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1540,7 +1540,13 @@ class Decoder(BaseGaussianDiffusion): @torch.no_grad() @eval_decorator - def sample(self, image_embed, text = None, cond_scale = 1.): + def sample( + self, + image_embed, + text = None, + cond_scale = 1., + stop_at_unet_number = None + ): batch_size = image_embed.shape[0] text_encodings = text_mask = None @@ -1552,7 +1558,7 @@ class Decoder(BaseGaussianDiffusion): img = None - for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)): + for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)): context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context() @@ -1584,6 +1590,9 @@ class Decoder(BaseGaussianDiffusion): img = vae.decode(img) + if exists(stop_at_unet_number) and stop_at_unet_number == unet_number: + break + return img def forward( diff --git a/setup.py b/setup.py index c78d057..441e8b0 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.81', + version = '0.0.82', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',