make sure some hyperparameters for unet block is configurable

This commit is contained in:
Phil Wang
2022-05-04 11:18:32 -07:00
parent 9359ad2e91
commit 5b619c2fd5
2 changed files with 5 additions and 3 deletions

View File

@@ -1196,6 +1196,8 @@ class Unet(nn.Module):
init_dim = None, init_dim = None,
init_conv_kernel_size = 7, init_conv_kernel_size = 7,
block_type = 'resnet', block_type = 'resnet',
block_resnet_groups = 8,
block_convnext_mult = 2,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
@@ -1274,9 +1276,9 @@ class Unet(nn.Module):
# whether to use resnet or the (improved?) convnext blocks # whether to use resnet or the (improved?) convnext blocks
if block_type == 'resnet': if block_type == 'resnet':
block_klass = ResnetBlock block_klass = partial(ResnetBlock, groups = block_resnet_groups)
elif block_type == 'convnext': elif block_type == 'convnext':
block_klass = ConvNextBlock block_klass = partial(ConvNextBlock, mult = block_convnext_mult)
else: else:
raise ValueError(f'unimplemented block type {block_type}') raise ValueError(f'unimplemented block type {block_type}')

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.95', version = '0.0.96',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',