mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
complete naive conditioning of unet with image embedding, with ability to dropout for classifier free guidance
@@ -231,7 +231,7 @@ class DiffusionPriorNetwork(nn.Module):
 
         # classifier free guidance
 
-        cond_prob_mask = prob_mask_like(batch_size, cond_prob_drop, device = device)
+        cond_prob_mask = prob_mask_like((batch_size,), cond_prob_drop, device = device)
         mask &= rearrange(cond_prob_mask, 'b -> b 1')
 
         # attend
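
Note on the prob_mask_like fix: passing the shape tuple (batch_size,) instead of a
bare int makes the intent explicit, that the helper should return one boolean per
sample. The helper itself is not shown in this diff; a minimal sketch of what it
plausibly looks like, modeled on the usual lucidrains-style helper (an assumption,
not the repository's confirmed definition):

    import torch

    def prob_mask_like(shape, prob, device = None):
        # hypothetical reconstruction: boolean tensor of `shape`,
        # with each element True with probability `prob`
        if prob == 1:
            return torch.ones(shape, device = device, dtype = torch.bool)
        elif prob == 0:
            return torch.zeros(shape, device = device, dtype = torch.bool)
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

Both call sites in this commit treat True as "keep the conditioning" (the mask is
AND-ed into the text mask here, and selects the real image embedding in the Unet
below), so whether cond_prob_drop acts as a keep or a drop probability hinges on
the helper's polarity and is worth verifying.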
@@ -290,7 +290,7 @@ class ConvNextBlock(nn.Module):
         dim,
         dim_out,
         *,
-        time_emb_dim = None,
+        cond_dim = None,
         mult = 2,
         norm = True
     ):
@@ -299,8 +299,8 @@ class ConvNextBlock(nn.Module):
 
         self.mlp = nn.Sequential(
             nn.GELU(),
-            nn.Linear(time_emb_dim, dim)
-        ) if exists(time_emb_dim) else None
+            nn.Linear(cond_dim, dim)
+        ) if exists(cond_dim) else None
 
         self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
 
@@ -314,12 +314,12 @@ class ConvNextBlock(nn.Module):
 
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
 
-    def forward(self, x, time_emb = None):
+    def forward(self, x, cond = None):
         h = self.ds_conv(x)
 
         if exists(self.mlp):
-            assert exists(time_emb)
-            condition = self.mlp(time_emb)
+            assert exists(cond)
+            condition = self.mlp(cond)
             h = h + rearrange(condition, 'b c -> b c 1 1')
 
         h = self.net(h)
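
The conditioning mechanism itself is unchanged: an MLP projects the conditioning
vector down to dim channels, and the projection is broadcast-added across the
spatial grid right after the depthwise conv. A self-contained sketch of just this
mechanism (TinyCondBlock and every size below are illustrative stand-ins; the full
ConvNext inner net is reduced to a single conv):

    import torch
    from torch import nn
    from einops import rearrange

    class TinyCondBlock(nn.Module):
        def __init__(self, dim, cond_dim):
            super().__init__()
            self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
            self.mlp = nn.Sequential(nn.GELU(), nn.Linear(cond_dim, dim))
            self.net = nn.Conv2d(dim, dim, 3, padding = 1)  # stand-in for the real inner net

        def forward(self, x, cond):
            h = self.ds_conv(x)
            # project cond to `dim` channels, then add it at every spatial position
            h = h + rearrange(self.mlp(cond), 'b c -> b c 1 1')
            return self.net(h) + x  # residual, as in the real block

    x = torch.randn(2, 32, 16, 16)
    cond = torch.randn(2, 96)  # e.g. time_dim 32 + image_embed_dim 64
    print(TinyCondBlock(32, 96)(x, cond).shape)  # torch.Size([2, 32, 16, 16])

The rename is the point: cond is no longer just a time embedding, it is the time
embedding concatenated with the (possibly dropped-out) image embedding.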
@@ -331,10 +331,10 @@ class Unet(nn.Module):
         dim,
         *,
         image_embed_dim,
+        time_dim = None,
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
-        with_time_emb = True
     ):
         super().__init__()
         self.channels = channels
@@ -342,17 +342,18 @@ class Unet(nn.Module):
         dims = [channels, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
-        if with_time_emb:
-            time_dim = dim
-            self.time_mlp = nn.Sequential(
-                SinusoidalPosEmb(dim),
-                nn.Linear(dim, dim * 4),
-                nn.GELU(),
-                nn.Linear(dim * 4, dim)
-            )
-        else:
-            time_dim = None
-            self.time_mlp = None
+        time_dim = default(time_dim, dim)
+
+        self.time_mlp = nn.Sequential(
+            SinusoidalPosEmb(dim),
+            nn.Linear(dim, dim * 4),
+            nn.GELU(),
+            nn.Linear(dim * 4, dim)
+        )
+
+        self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
+
+        cond_dim = time_dim + image_embed_dim
 
         self.downs = nn.ModuleList([])
         self.ups = nn.ModuleList([])
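
Two new pieces of state here: null_image_embed is a learned vector that stands in
for "no image conditioning" whenever an embedding is dropped, and cond_dim is the
width of the concatenated conditioning vector every ConvNextBlock now receives.
The shape arithmetic, under assumed sizes:

    import torch

    time_dim, image_embed_dim, batch = 64, 512, 4       # illustrative sizes
    t = torch.randn(batch, time_dim)                    # output of self.time_mlp
    image_embed = torch.randn(batch, image_embed_dim)
    cond = torch.cat((t, image_embed), dim = -1)        # built in forward, below
    assert cond.shape == (batch, time_dim + image_embed_dim)  # matches cond_dim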
@@ -362,20 +363,20 @@ class Unet(nn.Module):
             is_last = ind >= (num_resolutions - 1)
 
             self.downs.append(nn.ModuleList([
-                ConvNextBlock(dim_in, dim_out, time_emb_dim = time_dim, norm = ind != 0),
-                ConvNextBlock(dim_out, dim_out, time_emb_dim = time_dim),
+                ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
+                ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
 
         mid_dim = dims[-1]
-        self.mid_block = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
+        self.mid_block = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (num_resolutions - 1)
 
             self.ups.append(nn.ModuleList([
-                ConvNextBlock(dim_out * 2, dim_in, time_emb_dim = time_dim),
-                ConvNextBlock(dim_in, dim_in, time_emb_dim = time_dim),
+                ConvNextBlock(dim_out * 2, dim_in, cond_dim = cond_dim),
+                ConvNextBlock(dim_in, dim_in, cond_dim = cond_dim),
                 Upsample(dim_in) if not is_last else nn.Identity()
             ]))
 
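
For reference, the (dim_in, dim_out) pairs these loops iterate over come from the
dims / in_out lines in the previous hunk; with channels = 3, dim_mults = (1, 2, 4, 8)
and an assumed dim = 64:

    dim, channels, dim_mults = 64, 3, (1, 2, 4, 8)
    dims = [channels, *map(lambda m: dim * m, dim_mults)]  # [3, 64, 128, 256, 512]
    in_out = list(zip(dims[:-1], dims[1:]))
    print(in_out)  # [(3, 64), (64, 128), (128, 256), (256, 512)]

The up path walks reversed(in_out[1:]) and widens the first block of each stage to
dim_out * 2 to absorb the skip connection concatenated in the forward pass.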
@@ -408,10 +409,20 @@ class Unet(nn.Module):
         text_encodings = None,
         cond_prob_drop = 0.
     ):
         batch_size, device = image_embed.shape[0], image_embed.device
 
-        t = self.time_mlp(time) if exists(self.time_mlp) else None
+        t = self.time_mlp(time)
 
-        cond_prob_mask = prob_mask_like(batch_size, cond_prob_drop, device = device)
+        cond_prob_mask = prob_mask_like((batch_size,), cond_prob_drop, device = device)
 
+        # mask out image embedding depending on condition dropout
+        # for classifier free guidance
+        image_embed = torch.where(
+            rearrange(cond_prob_mask, 'b -> b 1'),
+            image_embed,
+            rearrange(self.null_image_embed, 'd -> 1 d')
+        )
+
+        cond = torch.cat((t, image_embed), dim = -1)
+
         hiddens = []
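
The dropout above is only the training half of classifier-free guidance; nothing in
this commit performs the sampling-time combination. For completeness, a sketch of
the standard combination (guided_pred and cond_scale are hypothetical, and it
assumes cond_prob_drop = 1. routes every row to null_image_embed, which is what the
name implies but ultimately depends on prob_mask_like's polarity):

    import torch

    def guided_pred(unet, x, time, image_embed, cond_scale = 3.):
        # conditioned pass: assumed to keep every image embedding
        cond = unet(x, time, image_embed = image_embed, cond_prob_drop = 0.)
        # unconditioned pass: assumed to replace every embedding
        # with the learned null_image_embed
        uncond = unet(x, time, image_embed = image_embed, cond_prob_drop = 1.)
        # extrapolate away from the unconditional prediction
        return uncond + (cond - uncond) * cond_scale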