mirror of https://github.com/lucidrains/DALLE2-pytorch.git
fix calculation of adaptive weight for vit-vqgan, thanks to @CiaoHe
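
Why the change: the adaptive-weight step in `VQGanVAE` reached directly into `self.decoders[-1].weight`, an attribute layout only the convolutional backbones share; `ViTEncDec` keeps its decoder under `self.decoder` with different indexing, so the lookup broke for vit-vqgan. The commit hides that detail behind a `last_dec_layer` property on each backbone and has the VAE delegate through `self.enc_dec`. A minimal sketch of the pattern with hypothetical stand-in modules (not the repository's actual classes):

```python
import torch.nn as nn

class ConvBackbone(nn.Module):
    """Stand-in for ResnetEncDec / ConvNextEncDec."""
    def __init__(self):
        super().__init__()
        self.decoders = nn.ModuleList([nn.Conv2d(8, 3, 3, padding = 1)])

    @property
    def last_dec_layer(self):
        return self.decoders[-1].weight  # weight of the final decoder conv

class ViTBackbone(nn.Module):
    """Stand-in for ViTEncDec -- it has no `decoders` list, which is why
    the old `self.decoders[-1].weight` access could not work for vit-vqgan."""
    def __init__(self):
        super().__init__()
        self.decoder = nn.Sequential(nn.Linear(8, 32), nn.Tanh(), nn.Linear(32, 48))

    @property
    def last_dec_layer(self):
        return self.decoder[-1].weight  # indexing differs per backbone

# the consumer no longer needs to know which backbone it holds
for enc_dec in (ConvBackbone(), ViTBackbone()):
    assert enc_dec.last_dec_layer.requires_grad
```
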
@@ -285,6 +285,10 @@ class ResnetEncDec(nn.Module):
     def get_encoded_fmap_size(self, image_size):
         return image_size // (2 ** self.layers)

+    @property
+    def last_dec_layer(self):
+        return self.decoders[-1].weight
+
     def encode(self, x):
         for enc in self.encoders:
             x = enc(x)
@@ -419,6 +423,10 @@ class ConvNextEncDec(nn.Module):
     def get_encoded_fmap_size(self, image_size):
         return image_size // (2 ** self.layers)

+    @property
+    def last_dec_layer(self):
+        return self.decoders[-1].weight
+
     def encode(self, x):
         for enc in self.encoders:
             x = enc(x)
@@ -606,6 +614,10 @@ class ViTEncDec(nn.Module):
     def get_encoded_fmap_size(self, image_size):
         return image_size // self.patch_size

+    @property
+    def last_dec_layer(self):
+        return self.decoder[-3][-1].weight
+
     def encode(self, x):
         return self.encoder(x)

@@ -843,7 +855,7 @@ class VQGanVAE(nn.Module):

         # calculate adaptive weight

-        last_dec_layer = self.decoders[-1].weight
+        last_dec_layer = self.enc_dec.last_dec_layer

         norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
         norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
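
For context on how `last_dec_layer` is consumed: `grad_layer_wrt_loss` takes the gradient of a scalar loss with respect to that single weight tensor, and the two gradient norms form the adaptive weight from the VQGAN recipe, which rescales the adversarial term against the perceptual one. A hedged sketch under that assumption; the helper name matches the diff, but the division, epsilon, and clamp value are illustrative, not read from this commit:

```python
import torch

def grad_layer_wrt_loss(loss, layer):
    # gradient of a scalar loss w.r.t. a single parameter tensor;
    # retain_graph so the graph survives the second call below
    return torch.autograd.grad(
        outputs = loss,
        inputs = layer,
        retain_graph = True
    )[0].detach()

def adaptive_weight(perceptual_loss, gen_loss, last_dec_layer, eps = 1e-8):
    norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
    norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
    # large when the GAN gradient is comparatively small, so the
    # adversarial term never drowns out reconstruction early in training
    return (norm_grad_wrt_perceptual_loss / (norm_grad_wrt_gen_loss + eps)).clamp(max = 1e4)

# toy usage: both losses must share a graph with `last_dec_layer`
layer = torch.randn(4, 4, requires_grad = True)
weight = adaptive_weight((layer ** 2).sum(), (layer * 0.5).sum(), layer)
```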