From f82917e1fdf1112cda2602265b9ff6d52fe50d7b Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 23 Apr 2022 07:52:10 -0700
Subject: [PATCH] prepare for turning off gradient penalty; as shown in the
 GAN literature, GP only needs to be applied 1 out of 4 iterations

---
 dalle2_pytorch/vqgan_vae.py | 9 +++++----
 setup.py                    | 2 +-
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py
index 380cd42..8441237 100644
--- a/dalle2_pytorch/vqgan_vae.py
+++ b/dalle2_pytorch/vqgan_vae.py
@@ -477,7 +477,8 @@ class VQGanVAE(nn.Module):
         img,
         return_loss = False,
         return_discr_loss = False,
-        return_recons = False
+        return_recons = False,
+        add_gradient_penalty = True
     ):
         batch, channels, height, width, device = *img.shape, img.device
         assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
@@ -502,11 +503,11 @@ class VQGanVAE(nn.Module):
 
             fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
 
-            gp = gradient_penalty(img, img_discr_logits)
-
             discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
 
-            loss = discr_loss + gp
+            if add_gradient_penalty:
+                gp = gradient_penalty(img, img_discr_logits)
+                loss = discr_loss + gp
 
             if return_recons:
                 return loss, fmap
diff --git a/setup.py b/setup.py
index 9c3781b..143c52d 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.38',
+  version = '0.0.39',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
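
Usage sketch: the new add_gradient_penalty flag is groundwork for applying the gradient penalty lazily, i.e. on only 1 out of every 4 discriminator updates as the subject line describes (cf. the lazy regularization trick in StyleGAN2, Karras et al. 2020). Below is a minimal training-loop sketch of that schedule; it is illustrative, not part of this patch. The constructor arguments, optimizer, and dummy batch are assumptions, and the loop presumes a follow-up change in which the forward pass falls back to loss = discr_loss when the penalty is skipped; at this exact commit, passing add_gradient_penalty = False raises a NameError, since loss is only assigned inside the if add_gradient_penalty: branch. Generator (autoencoder) updates are omitted for brevity.

    import torch
    from dalle2_pytorch.vqgan_vae import VQGanVAE

    # illustrative setup; constructor arguments are assumptions, not taken from this patch
    vae = VQGanVAE(dim = 32, image_size = 256)
    discr_opt = torch.optim.Adam(vae.discr.parameters(), lr = 3e-4)

    images = torch.randn(4, 3, 256, 256)  # dummy batch standing in for a real dataloader

    for step in range(8):
        # lazy gradient penalty: apply GP on only 1 out of every 4 iterations
        apply_gp = (step % 4) == 0

        # NOTE: relies on a follow-up where loss defaults to discr_loss when the
        # penalty is skipped; at this commit the False path is not yet wired up
        discr_loss = vae(images, return_discr_loss = True, add_gradient_penalty = apply_gp)

        discr_opt.zero_grad()
        discr_loss.backward()
        discr_opt.step()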