fix vqgan-vae for latent diffusion

This commit is contained in:
Phil Wang
2022-04-22 11:38:57 -07:00
parent 46cef31c86
commit 461347c171
2 changed files with 5 additions and 5 deletions

View File

@@ -435,12 +435,12 @@ class VQGanVAE(nn.Module):
return fmap
def decode(self, fmap):
fmap = self.vq(fmap)
fmap, indices, commit_loss = self.vq(fmap)
for dec in self.decoders:
fmap = dec(fmap)
return fmap
return fmap, indices, commit_loss
def forward(
self,
@@ -453,9 +453,9 @@ class VQGanVAE(nn.Module):
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
fmap, indices, commit_loss = self.encode(img)
fmap = self.encode(img)
fmap = self.decode(fmap)
fmap, indices, commit_loss = self.decode(fmap)
if not return_loss and not return_discr_loss:
return fmap