fix calculation of adaptive weight for vit-vqgan, thanks to @CiaoHe

This commit is contained in:
Phil Wang
2022-05-02 07:57:28 -07:00
parent c1db2753f5
commit fc954ee788
2 changed files with 14 additions and 2 deletions

View File

@@ -285,6 +285,10 @@ class ResnetEncDec(nn.Module):
def get_encoded_fmap_size(self, image_size): def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers) return image_size // (2 ** self.layers)
# Property exposing the `weight` of the last entry in `self.decoders`.
# VQGanVAE reads it via `self.enc_dec.last_dec_layer` when calculating the
# adaptive weight: it takes the gradient of the generator loss and of the
# perceptual loss with respect to this tensor and compares their L2 norms.
@property
def last_dec_layer(self):
return self.decoders[-1].weight
def encode(self, x): def encode(self, x):
for enc in self.encoders: for enc in self.encoders:
x = enc(x) x = enc(x)
@@ -419,6 +423,10 @@ class ConvNextEncDec(nn.Module):
def get_encoded_fmap_size(self, image_size): def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers) return image_size // (2 ** self.layers)
# Property exposing the `weight` of the last entry in `self.decoders`.
# It gives VQGanVAE one uniform attribute (`self.enc_dec.last_dec_layer`) to
# differentiate against during the adaptive weight calculation, whichever
# encoder/decoder class is in use.
@property
def last_dec_layer(self):
return self.decoders[-1].weight
def encode(self, x): def encode(self, x):
for enc in self.encoders: for enc in self.encoders:
x = enc(x) x = enc(x)
@@ -606,6 +614,10 @@ class ViTEncDec(nn.Module):
def get_encoded_fmap_size(self, image_size): def get_encoded_fmap_size(self, image_size):
return image_size // self.patch_size return image_size // self.patch_size
# Property exposing the tensor that VQGanVAE differentiates against (via
# `self.enc_dec.last_dec_layer`) during the adaptive weight calculation.
# It takes the third-from-last entry of `self.decoder`, then the last element
# inside that entry, and returns its `weight`.
# NOTE(review): the hard-coded [-3][-1] depends on how `self.decoder` is laid
# out, which is not visible here; confirm it still points at the final
# weighted layer if the decoder is restructured.
@property
def last_dec_layer(self):
return self.decoder[-3][-1].weight
def encode(self, x): def encode(self, x):
return self.encoder(x) return self.encoder(x)
@@ -843,7 +855,7 @@ class VQGanVAE(nn.Module):
# calculate adaptive weight # calculate adaptive weight
last_dec_layer = self.decoders[-1].weight last_dec_layer = self.enc_dec.last_dec_layer
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2) norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2) norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.88', version = '0.0.89',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',