mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Switch over to scale-shift conditioning, since Imagen and GLIDE appear to have used it and it may be important
This commit is contained in:
@@ -1107,13 +1107,20 @@ class Block(nn.Module):
|
|||||||
groups = 8
|
groups = 8
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.block = nn.Sequential(
|
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||||
nn.Conv2d(dim, dim_out, 3, padding = 1),
|
self.norm = nn.GroupNorm(groups, dim_out)
|
||||||
nn.GroupNorm(groups, dim_out),
|
self.act = nn.SiLU()
|
||||||
nn.SiLU()
|
|
||||||
)
|
def forward(self, x, scale_shift = None):
|
||||||
def forward(self, x):
|
x = self.project(x)
|
||||||
return self.block(x)
|
x = self.norm(x)
|
||||||
|
|
||||||
|
if exists(scale_shift):
|
||||||
|
scale, shift = scale_shift
|
||||||
|
x = x * (scale + 1) + shift
|
||||||
|
|
||||||
|
x = self.act(x)
|
||||||
|
return x
|
||||||
|
|
||||||
class ResnetBlock(nn.Module):
|
class ResnetBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -1132,7 +1139,7 @@ class ResnetBlock(nn.Module):
|
|||||||
if exists(time_cond_dim):
|
if exists(time_cond_dim):
|
||||||
self.time_mlp = nn.Sequential(
|
self.time_mlp = nn.Sequential(
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Linear(time_cond_dim, dim_out)
|
nn.Linear(time_cond_dim, dim_out * 2)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cross_attn = None
|
self.cross_attn = None
|
||||||
@@ -1152,11 +1159,14 @@ class ResnetBlock(nn.Module):
|
|||||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, cond = None, time_emb = None):
|
def forward(self, x, cond = None, time_emb = None):
|
||||||
h = self.block1(x)
|
|
||||||
|
|
||||||
|
scale_shift = None
|
||||||
if exists(self.time_mlp) and exists(time_emb):
|
if exists(self.time_mlp) and exists(time_emb):
|
||||||
time_emb = self.time_mlp(time_emb)
|
time_emb = self.time_mlp(time_emb)
|
||||||
h = rearrange(time_emb, 'b c -> b c 1 1') + h
|
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
|
||||||
|
scale_shift = time_emb.chunk(2, dim = 1)
|
||||||
|
|
||||||
|
h = self.block1(x, scale_shift = scale_shift)
|
||||||
|
|
||||||
if exists(self.cross_attn):
|
if exists(self.cross_attn):
|
||||||
assert exists(cond)
|
assert exists(cond)
|
||||||
|
|||||||
Reference in New Issue
Block a user