Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 09:44:19 +01:00)
Zero-initialize the final projection in the Unet, since both OpenAI and @crowsonkb do the same
This commit is contained in:
@@ -77,6 +77,11 @@ def cast_tuple(val, length = None):
|
||||
def module_device(module):
|
||||
return next(module.parameters()).device
|
||||
|
||||
def zero_init_(m):
|
||||
nn.init.zeros_(m.weight)
|
||||
if exists(m.bias):
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
@contextmanager
|
||||
def null_context(*args, **kwargs):
|
||||
yield
|
||||
@@ -1669,6 +1674,8 @@ class Unet(nn.Module):
|
||||
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
|
||||
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
|
||||
|
||||
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
# for cascading DDPM, then reinit the unet with the right settings
|
||||
def cast_model_parameters(
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.20.0'
|
||||
__version__ = '0.20.1'
|
||||
|
||||
Reference in New Issue
Block a user