From 461347c1710dac7c206a4afabc114e428e51300c Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 22 Apr 2022 11:38:57 -0700 Subject: [PATCH] fix vqgan-vae for latent diffusion --- dalle2_pytorch/vqgan_vae.py | 8 ++++---- setup.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py index c519a60..59a194f 100644 --- a/dalle2_pytorch/vqgan_vae.py +++ b/dalle2_pytorch/vqgan_vae.py @@ -435,12 +435,12 @@ class VQGanVAE(nn.Module): return fmap def decode(self, fmap): - fmap = self.vq(fmap) + fmap, indices, commit_loss = self.vq(fmap) for dec in self.decoders: fmap = dec(fmap) - return fmap + return fmap, indices, commit_loss def forward( self, @@ -453,9 +453,9 @@ class VQGanVAE(nn.Module): assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}' assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE' - fmap, indices, commit_loss = self.encode(img) + fmap = self.encode(img) - fmap = self.decode(fmap) + fmap, indices, commit_loss = self.decode(fmap) if not return_loss and not return_discr_loss: return fmap diff --git a/setup.py b/setup.py index fef0064..4914bed 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.35', + version = '0.0.36', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',