mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
fix vqgan-vae for latent diffusion
This commit is contained in:
@@ -435,12 +435,12 @@ class VQGanVAE(nn.Module):
|
||||
return fmap
|
||||
|
||||
def decode(self, fmap):
|
||||
fmap = self.vq(fmap)
|
||||
fmap, indices, commit_loss = self.vq(fmap)
|
||||
|
||||
for dec in self.decoders:
|
||||
fmap = dec(fmap)
|
||||
|
||||
return fmap
|
||||
return fmap, indices, commit_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -453,9 +453,9 @@ class VQGanVAE(nn.Module):
|
||||
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
||||
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
|
||||
|
||||
fmap, indices, commit_loss = self.encode(img)
|
||||
fmap = self.encode(img)
|
||||
|
||||
fmap = self.decode(fmap)
|
||||
fmap, indices, commit_loss = self.decode(fmap)
|
||||
|
||||
if not return_loss and not return_discr_loss:
|
||||
return fmap
|
||||
|
||||
Reference in New Issue
Block a user