From fc954ee788cddecf245ff84735112635122f00a8 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 2 May 2022 07:57:28 -0700
Subject: [PATCH] fix calculation of adaptive weight for vit-vqgan, thanks to @CiaoHe

---
 dalle2_pytorch/vqgan_vae.py | 14 +++++++++++++-
 setup.py                    |  2 +-
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py
index bc203b5..7c97e6b 100644
--- a/dalle2_pytorch/vqgan_vae.py
+++ b/dalle2_pytorch/vqgan_vae.py
@@ -285,6 +285,10 @@ class ResnetEncDec(nn.Module):
     def get_encoded_fmap_size(self, image_size):
         return image_size // (2 ** self.layers)
 
+    @property
+    def last_dec_layer(self):
+        return self.decoders[-1].weight
+
     def encode(self, x):
         for enc in self.encoders:
             x = enc(x)
@@ -419,6 +423,10 @@ class ConvNextEncDec(nn.Module):
     def get_encoded_fmap_size(self, image_size):
         return image_size // (2 ** self.layers)
 
+    @property
+    def last_dec_layer(self):
+        return self.decoders[-1].weight
+
     def encode(self, x):
         for enc in self.encoders:
             x = enc(x)
@@ -606,6 +614,10 @@ class ViTEncDec(nn.Module):
     def get_encoded_fmap_size(self, image_size):
         return image_size // self.patch_size
 
+    @property
+    def last_dec_layer(self):
+        return self.decoder[-3][-1].weight
+
     def encode(self, x):
         return self.encoder(x)
 
@@ -843,7 +855,7 @@ class VQGanVAE(nn.Module):
 
         # calculate adaptive weight
 
-        last_dec_layer = self.decoders[-1].weight
+        last_dec_layer = self.enc_dec.last_dec_layer
 
         norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
         norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
diff --git a/setup.py b/setup.py
index 178a666..2309d4c 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.88',
+  version = '0.0.89',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
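
Context for the fix: VQGanVAE previously reached for self.decoders[-1].weight, which only exists on the resnet/convnext encoder-decoders; the ViT variant stores its decoder as a Sequential, so the lookup broke there. The patch moves the lookup behind a last_dec_layer property on each encoder-decoder class. Below is a minimal sketch of where that weight is then used, assembled from the grad_layer_wrt_loss call visible in the diff context; the function wrapper, clamp values, and the exact division are assumptions, not the verbatim code in vqgan_vae.py.

import torch
from torch.autograd import grad as torch_grad

def grad_layer_wrt_loss(loss, layer):
    # gradient of a scalar loss with respect to a single parameter tensor
    return torch_grad(
        outputs = loss,
        inputs = layer,
        grad_outputs = torch.ones_like(loss),
        retain_graph = True
    )[0].detach()

def calculate_adaptive_weight(gen_loss, perceptual_loss, last_dec_layer, clamp_max = 1e4):
    # balance the adversarial (generator) loss against the perceptual loss
    # by the ratio of their gradient norms at the last decoder layer
    norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
    norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)

    # clamp values are assumptions for numerical stability, not taken from the patch
    adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min = 1e-8)
    return adaptive_weight.clamp(max = clamp_max)

With the new property, the same call works for all three encoder-decoder variants, e.g. calculate_adaptive_weight(gen_loss, perceptual_loss, enc_dec.last_dec_layer).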