diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py index c519a60..59a194f 100644 --- a/dalle2_pytorch/vqgan_vae.py +++ b/dalle2_pytorch/vqgan_vae.py @@ -435,12 +435,12 @@ class VQGanVAE(nn.Module): return fmap def decode(self, fmap): - fmap = self.vq(fmap) + fmap, indices, commit_loss = self.vq(fmap) for dec in self.decoders: fmap = dec(fmap) - return fmap + return fmap, indices, commit_loss def forward( self, @@ -453,9 +453,9 @@ class VQGanVAE(nn.Module): assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}' assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE' - fmap, indices, commit_loss = self.encode(img) + fmap = self.encode(img) - fmap = self.decode(fmap) + fmap, indices, commit_loss = self.decode(fmap) if not return_loss and not return_discr_loss: return fmap diff --git a/setup.py b/setup.py index fef0064..4914bed 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.35', + version = '0.0.36', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',