mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
switch over to scale shift conditioning, as it seems like Imagen and Glide used it and it may be important
This commit is contained in:
@@ -1107,13 +1107,20 @@ class Block(nn.Module):
|
||||
groups = 8
|
||||
):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(dim, dim_out, 3, padding = 1),
|
||||
nn.GroupNorm(groups, dim_out),
|
||||
nn.SiLU()
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||
self.norm = nn.GroupNorm(groups, dim_out)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x, scale_shift = None):
|
||||
x = self.project(x)
|
||||
x = self.norm(x)
|
||||
|
||||
if exists(scale_shift):
|
||||
scale, shift = scale_shift
|
||||
x = x * (scale + 1) + shift
|
||||
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
@@ -1132,7 +1139,7 @@ class ResnetBlock(nn.Module):
|
||||
if exists(time_cond_dim):
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_cond_dim, dim_out)
|
||||
nn.Linear(time_cond_dim, dim_out * 2)
|
||||
)
|
||||
|
||||
self.cross_attn = None
|
||||
@@ -1152,11 +1159,14 @@ class ResnetBlock(nn.Module):
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, cond = None, time_emb = None):
|
||||
h = self.block1(x)
|
||||
|
||||
scale_shift = None
|
||||
if exists(self.time_mlp) and exists(time_emb):
|
||||
time_emb = self.time_mlp(time_emb)
|
||||
h = rearrange(time_emb, 'b c -> b c 1 1') + h
|
||||
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
|
||||
scale_shift = time_emb.chunk(2, dim = 1)
|
||||
|
||||
h = self.block1(x, scale_shift = scale_shift)
|
||||
|
||||
if exists(self.cross_attn):
|
||||
assert exists(cond)
|
||||
|
||||
Reference in New Issue
Block a user