diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 24cd6e9..d388ba3 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -77,6 +77,11 @@ def cast_tuple(val, length = None): def module_device(module): return next(module.parameters()).device +def zero_init_(m): + nn.init.zeros_(m.weight) + if exists(m.bias): + nn.init.zeros_(m.bias) + @contextmanager def null_context(*args, **kwargs): yield @@ -1669,6 +1674,8 @@ class Unet(nn.Module): self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2) + zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it + # if the current settings for the unet are not correct # for cascading DDPM, then reinit the unet with the right settings def cast_model_parameters( diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 2f15b8c..abadaef 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.20.0' +__version__ = '0.20.1'