complete naive conditioning of unet with image embedding, with ability to dropout for classifier free guidance

This commit is contained in:
Phil Wang
2022-04-12 17:27:39 -07:00
parent d546a615c0
commit 25d980ebbf

View File

@@ -231,7 +231,7 @@ class DiffusionPriorNetwork(nn.Module):
# classifier free guidance # classifier free guidance
cond_prob_mask = prob_mask_like(batch_size, cond_prob_drop, device = device) cond_prob_mask = prob_mask_like((batch_size,), cond_prob_drop, device = device)
mask &= rearrange(cond_prob_mask, 'b -> b 1') mask &= rearrange(cond_prob_mask, 'b -> b 1')
# attend # attend
@@ -290,7 +290,7 @@ class ConvNextBlock(nn.Module):
dim, dim,
dim_out, dim_out,
*, *,
time_emb_dim = None, cond_dim = None,
mult = 2, mult = 2,
norm = True norm = True
): ):
@@ -299,8 +299,8 @@ class ConvNextBlock(nn.Module):
self.mlp = nn.Sequential( self.mlp = nn.Sequential(
nn.GELU(), nn.GELU(),
nn.Linear(time_emb_dim, dim) nn.Linear(cond_dim, dim)
) if exists(time_emb_dim) else None ) if exists(cond_dim) else None
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim) self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
@@ -314,12 +314,12 @@ class ConvNextBlock(nn.Module):
self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity() self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
def forward(self, x, time_emb = None): def forward(self, x, cond = None):
h = self.ds_conv(x) h = self.ds_conv(x)
if exists(self.mlp): if exists(self.mlp):
assert exists(time_emb) assert exists(cond)
condition = self.mlp(time_emb) condition = self.mlp(cond)
h = h + rearrange(condition, 'b c -> b c 1 1') h = h + rearrange(condition, 'b c -> b c 1 1')
h = self.net(h) h = self.net(h)
@@ -331,10 +331,10 @@ class Unet(nn.Module):
dim, dim,
*, *,
image_embed_dim, image_embed_dim,
time_dim = None,
out_dim = None, out_dim = None,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8),
channels = 3, channels = 3,
with_time_emb = True
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
@@ -342,17 +342,18 @@ class Unet(nn.Module):
dims = [channels, *map(lambda m: dim * m, dim_mults)] dims = [channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:])) in_out = list(zip(dims[:-1], dims[1:]))
if with_time_emb: time_dim = default(time_dim, dim)
time_dim = dim
self.time_mlp = nn.Sequential( self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim), SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4), nn.Linear(dim, dim * 4),
nn.GELU(), nn.GELU(),
nn.Linear(dim * 4, dim) nn.Linear(dim * 4, dim)
) )
else:
time_dim = None self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
self.time_mlp = None
cond_dim = time_dim + image_embed_dim
self.downs = nn.ModuleList([]) self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([]) self.ups = nn.ModuleList([])
@@ -362,20 +363,20 @@ class Unet(nn.Module):
is_last = ind >= (num_resolutions - 1) is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([ self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, time_emb_dim = time_dim, norm = ind != 0), ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
ConvNextBlock(dim_out, dim_out, time_emb_dim = time_dim), ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
Downsample(dim_out) if not is_last else nn.Identity() Downsample(dim_out) if not is_last else nn.Identity()
])) ]))
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid_block = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim) self.mid_block = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1) is_last = ind >= (num_resolutions - 1)
self.ups.append(nn.ModuleList([ self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, time_emb_dim = time_dim), ConvNextBlock(dim_out * 2, dim_in, cond_dim = cond_dim),
ConvNextBlock(dim_in, dim_in, time_emb_dim = time_dim), ConvNextBlock(dim_in, dim_in, cond_dim = cond_dim),
Upsample(dim_in) if not is_last else nn.Identity() Upsample(dim_in) if not is_last else nn.Identity()
])) ]))
@@ -408,10 +409,20 @@ class Unet(nn.Module):
text_encodings = None, text_encodings = None,
cond_prob_drop = 0. cond_prob_drop = 0.
): ):
batch_size, device = image_embed.shape[0], image_embed.device t = self.time_mlp(time)
t = self.time_mlp(time) if exists(self.time_mlp) else None
cond_prob_mask = prob_mask_like(batch_size, cond_prob_drop, device = device) cond_prob_mask = prob_mask_like((batch_size,), cond_prob_drop, device = device)
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_embed = torch.where(
rearrange(cond_prob_mask, 'b -> b 1'),
image_embed,
rearrange(self.null_image_embed, 'd -> 1 d')
)
cond = torch.cat((t, image_embed), dim = -1)
hiddens = [] hiddens = []