mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 02:04:19 +01:00
prepare for turning off the gradient penalty; as shown in the GAN literature, GP only needs to be applied on 1 out of every 4 iterations
This commit is contained in:
@@ -477,7 +477,8 @@ class VQGanVAE(nn.Module):
|
|||||||
img,
|
img,
|
||||||
return_loss = False,
|
return_loss = False,
|
||||||
return_discr_loss = False,
|
return_discr_loss = False,
|
||||||
return_recons = False
|
return_recons = False,
|
||||||
|
add_gradient_penalty = True
|
||||||
):
|
):
|
||||||
batch, channels, height, width, device = *img.shape, img.device
|
batch, channels, height, width, device = *img.shape, img.device
|
||||||
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
||||||
@@ -502,10 +503,10 @@ class VQGanVAE(nn.Module):
|
|||||||
|
|
||||||
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
|
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
|
||||||
|
|
||||||
gp = gradient_penalty(img, img_discr_logits)
|
|
||||||
|
|
||||||
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
|
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
|
||||||
|
|
||||||
|
if add_gradient_penalty:
|
||||||
|
gp = gradient_penalty(img, img_discr_logits)
|
||||||
loss = discr_loss + gp
|
loss = discr_loss + gp
|
||||||
|
|
||||||
if return_recons:
|
if return_recons:
|
||||||
|
|||||||
Reference in New Issue
Block a user