diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 5dc4f41..7bdfee3 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1107,13 +1107,20 @@ class Block(nn.Module):
         groups = 8
     ):
         super().__init__()
-        self.block = nn.Sequential(
-            nn.Conv2d(dim, dim_out, 3, padding = 1),
-            nn.GroupNorm(groups, dim_out),
-            nn.SiLU()
-        )
-    def forward(self, x):
-        return self.block(x)
+        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
+        self.norm = nn.GroupNorm(groups, dim_out)
+        self.act = nn.SiLU()
+
+    def forward(self, x, scale_shift = None):
+        x = self.project(x)
+        x = self.norm(x)
+
+        if exists(scale_shift):
+            scale, shift = scale_shift
+            x = x * (scale + 1) + shift
+
+        x = self.act(x)
+        return x
 
 class ResnetBlock(nn.Module):
     def __init__(
@@ -1132,7 +1139,7 @@ class ResnetBlock(nn.Module):
         if exists(time_cond_dim):
             self.time_mlp = nn.Sequential(
                 nn.SiLU(),
-                nn.Linear(time_cond_dim, dim_out)
+                nn.Linear(time_cond_dim, dim_out * 2)
             )
 
         self.cross_attn = None
@@ -1152,11 +1159,14 @@ class ResnetBlock(nn.Module):
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
     def forward(self, x, cond = None, time_emb = None):
-        h = self.block1(x)
 
+        scale_shift = None
         if exists(self.time_mlp) and exists(time_emb):
             time_emb = self.time_mlp(time_emb)
-            h = rearrange(time_emb, 'b c -> b c 1 1') + h
+            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
+            scale_shift = time_emb.chunk(2, dim = 1)
+
+        h = self.block1(x, scale_shift = scale_shift)
 
         if exists(self.cross_attn):
             assert exists(cond)
diff --git a/setup.py b/setup.py
index e805eda..7cd4afc 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.5.1',
+  version = '0.5.2',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',