first pass at complete DALL-E2 + Latent Diffusion integration, latent diffusion on any layer(s) of the cascading ddpm in the decoder.

This commit is contained in:
Phil Wang
2022-04-22 13:53:13 -07:00
parent f2d5b87677
commit 76b32f18b3
4 changed files with 204 additions and 30 deletions

View File

@@ -294,7 +294,7 @@ class VQGanVAE(nn.Module):
dim,
image_size,
channels = 3,
num_layers = 4,
layers = 4,
layer_mults = None,
l2_recon_loss = False,
use_hinge_loss = True,
@@ -321,35 +321,37 @@ class VQGanVAE(nn.Module):
self.image_size = image_size
self.channels = channels
self.num_layers = num_layers
self.fmap_size = image_size // (num_layers ** 2)
self.layers = layers
self.fmap_size = image_size // (layers ** 2)
self.codebook_size = vq_codebook_size
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
codebook_dim = layer_dims[-1]
self.encoded_dim = dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (num_layers - 1)), use_attn)
use_attn = (*((False,) * (layers - 1)), use_attn)
assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == num_layers
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
@@ -434,12 +436,15 @@ class VQGanVAE(nn.Module):
return fmap
def decode(self, fmap):
def decode(self, fmap, return_indices_and_loss = False):
fmap, indices, commit_loss = self.vq(fmap)
for dec in self.decoders:
fmap = dec(fmap)
if not return_indices_and_loss:
return fmap
return fmap, indices, commit_loss
def forward(
@@ -455,7 +460,7 @@ class VQGanVAE(nn.Module):
fmap = self.encode(img)
fmap, indices, commit_loss = self.decode(fmap)
fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
if not return_loss and not return_discr_loss:
return fmap