From ebe01749ed0a48aa77e236eb609440db0944eada Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 30 Apr 2022 14:55:34 -0700 Subject: [PATCH] DecoderTrainer sample method uses the exponentially moving averaged --- README.md | 8 +++++++- dalle2_pytorch/train.py | 16 ++++++++++++++++ setup.py | 2 +- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a7686d6..387837a 100644 --- a/README.md +++ b/README.md @@ -760,7 +760,7 @@ decoder = Decoder( unet = (unet1, unet2), image_sizes = (128, 256), clip = clip, - timesteps = 1, + timesteps = 1000, condition_on_text_encodings = True ).cuda() @@ -778,6 +778,12 @@ for unet_number in (1, 2): loss.backward() decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average + +# after much training +# you can sample from the exponentially moving averaged unets as so + +mock_image_embed = torch.randn(4, 512).cuda() +images = decoder.sample(mock_image_embed, text = text) # (4, 3, 256, 256) ``` ## CLI (wip) diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py index 7f3a657..0868182 100644 --- a/dalle2_pytorch/train.py +++ b/dalle2_pytorch/train.py @@ -144,6 +144,10 @@ class DecoderTrainer(nn.Module): self.max_grad_norm = max_grad_norm + @property + def unets(self): + return nn.ModuleList([ema.ema_model for ema in self.ema_unets]) + def scale(self, loss, *, unet_number): assert 1 <= unet_number <= self.num_unets index = unet_number - 1 @@ -169,6 +173,18 @@ class DecoderTrainer(nn.Module): ema_unet = self.ema_unets[index] ema_unet.update() + @torch.no_grad() + def sample(self, *args, **kwargs): + if self.use_ema: + trainable_unets = self.decoder.unets + self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling + + output = self.decoder.sample(*args, **kwargs) + + if self.use_ema: + self.decoder.unets = trainable_unets # restore original training unets + return output + def forward( self, x, diff --git a/setup.py b/setup.py index 0f56c1f..c78d057 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.80', + version = '0.0.81', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',