zero init final projection in unet, since openai and @crowsonkb are both doing it

This commit is contained in:
Phil Wang
2022-07-11 13:22:06 -07:00
parent 1f1557c614
commit bdd62c24b3
2 changed files with 8 additions and 1 deletions

View File

@@ -77,6 +77,11 @@ def cast_tuple(val, length = None):
def module_device(module): def module_device(module):
return next(module.parameters()).device return next(module.parameters()).device
def zero_init_(m):
nn.init.zeros_(m.weight)
if exists(m.bias):
nn.init.zeros_(m.bias)
@contextmanager @contextmanager
def null_context(*args, **kwargs): def null_context(*args, **kwargs):
yield yield
@@ -1669,6 +1674,8 @@ class Unet(nn.Module):
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2) self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
# if the current settings for the unet are not correct # if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings # for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters( def cast_model_parameters(

View File

@@ -1 +1 @@
__version__ = '0.20.0' __version__ = '0.20.1'