mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-22 02:04:24 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f82917e1fd |
@@ -477,7 +477,8 @@ class VQGanVAE(nn.Module):
|
|||||||
img,
|
img,
|
||||||
return_loss = False,
|
return_loss = False,
|
||||||
return_discr_loss = False,
|
return_discr_loss = False,
|
||||||
return_recons = False
|
return_recons = False,
|
||||||
|
add_gradient_penalty = True
|
||||||
):
|
):
|
||||||
batch, channels, height, width, device = *img.shape, img.device
|
batch, channels, height, width, device = *img.shape, img.device
|
||||||
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
||||||
@@ -502,11 +503,11 @@ class VQGanVAE(nn.Module):
|
|||||||
|
|
||||||
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
|
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
|
||||||
|
|
||||||
gp = gradient_penalty(img, img_discr_logits)
|
|
||||||
|
|
||||||
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
|
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
|
||||||
|
|
||||||
loss = discr_loss + gp
|
if add_gradient_penalty:
|
||||||
|
gp = gradient_penalty(img, img_discr_logits)
|
||||||
|
loss = discr_loss + gp
|
||||||
|
|
||||||
if return_recons:
|
if return_recons:
|
||||||
return loss, fmap
|
return loss, fmap
|
||||||
|
|||||||
Reference in New Issue
Block a user