mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 10:14:19 +01:00
prepare for turning off gradient penalty, as shown in GAN literature, GP needs to be only applied 1 out of 4 iterations
This commit is contained in:
@@ -477,7 +477,8 @@ class VQGanVAE(nn.Module):
|
||||
img,
|
||||
return_loss = False,
|
||||
return_discr_loss = False,
|
||||
return_recons = False
|
||||
return_recons = False,
|
||||
add_gradient_penalty = True
|
||||
):
|
||||
batch, channels, height, width, device = *img.shape, img.device
|
||||
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
||||
@@ -502,11 +503,11 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
|
||||
|
||||
gp = gradient_penalty(img, img_discr_logits)
|
||||
|
||||
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
|
||||
|
||||
loss = discr_loss + gp
|
||||
if add_gradient_penalty:
|
||||
gp = gradient_penalty(img, img_discr_logits)
|
||||
loss = discr_loss + gp
|
||||
|
||||
if return_recons:
|
||||
return loss, fmap
|
||||
|
||||
Reference in New Issue
Block a user