use null container pattern to cleanup some conditionals, save more cleanup for next week

This commit is contained in:
Phil Wang
2022-04-22 15:23:11 -07:00
parent a8b5d5d753
commit 05b74be69a
3 changed files with 44 additions and 18 deletions

View File

@@ -287,6 +287,28 @@ class VQGanAttention(nn.Module):
return out + residual
class NullVQGanVAE(nn.Module):
def __init__(
self,
*,
channels
):
super().__init__()
self.encoded_dim = channels
self.layers = 0
def get_encoded_fmap_size(self, size):
return size
def copy_for_eval(self):
return self
def encode(self, x):
return x
def decode(self, x):
return x
class VQGanVAE(nn.Module):
def __init__(
self,
@@ -407,6 +429,9 @@ class VQGanVAE(nn.Module):
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
def copy_for_eval(self):
device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu())