prepare for turning off the gradient penalty; as shown in the GAN literature, GP only needs to be applied on 1 out of every 4 iterations

Phil Wang
2022-04-23 07:52:10 -07:00
parent 05b74be69a
commit f82917e1fd
2 changed files with 6 additions and 5 deletions
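For context on the schedule the message describes, below is a minimal sketch of a discriminator update loop driving the new flag so that the gradient penalty lands on 1 out of every 4 iterations. The loop itself is not part of this commit, and `vae`, `dataloader`, and `discr_optim` are illustrative names:

# hypothetical discriminator training loop; only the
# `add_gradient_penalty` keyword comes from this commit
for step, img in enumerate(dataloader):
    apply_gp = (step % 4 == 0)  # gradient penalty on 1 out of every 4 steps

    discr_loss = vae(
        img,
        return_discr_loss = True,
        add_gradient_penalty = apply_gp
    )

    discr_loss.backward()
    discr_optim.step()
    discr_optim.zero_grad()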

dalle2_pytorch/vqgan_vae.py

@@ -477,7 +477,8 @@ class VQGanVAE(nn.Module):
         img,
         return_loss = False,
         return_discr_loss = False,
-        return_recons = False
+        return_recons = False,
+        add_gradient_penalty = True
     ):
         batch, channels, height, width, device = *img.shape, img.device
         assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
@@ -502,11 +503,11 @@ class VQGanVAE(nn.Module):
             fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
 
-            gp = gradient_penalty(img, img_discr_logits)
-
             discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
 
-            loss = discr_loss + gp
+            if add_gradient_penalty:
+                gp = gradient_penalty(img, img_discr_logits)
+                loss = discr_loss + gp
 
             if return_recons:
                 return loss, fmap
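For reference, the `gradient_penalty` this hunk calls is typically the standard one-centered gradient penalty from WGAN-GP. A self-contained sketch of such a helper follows; the repo's actual implementation may differ in details such as the penalty weight:

import torch
from torch.autograd import grad as torch_grad

# sketch of a standard WGAN-GP style penalty; `images` must already
# have requires_grad=True upstream (as the discriminator branch sets up)
def gradient_penalty(images, output, weight = 10):
    batch_size = images.shape[0]

    # gradients of discriminator logits with respect to the real images
    gradients = torch_grad(
        outputs = output,
        inputs = images,
        grad_outputs = torch.ones(output.size(), device = images.device),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )[0]

    # penalize deviation of the per-sample gradient norm from 1
    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

One detail worth noting in the hunk above: `loss` is only assigned inside the `if add_gradient_penalty:` branch, so the flag defaults to True and actually turning the penalty off still requires the follow-up change the commit message prepares for.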

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.38',
+  version = '0.0.39',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',