make sure some hyperparameters for unet block is configurable

This commit is contained in:
Phil Wang
2022-05-04 11:18:32 -07:00
parent 9359ad2e91
commit 5b619c2fd5
2 changed files with 5 additions and 3 deletions

View File

@@ -1196,6 +1196,8 @@ class Unet(nn.Module):
init_dim = None,
init_conv_kernel_size = 7,
block_type = 'resnet',
block_resnet_groups = 8,
block_convnext_mult = 2,
**kwargs
):
super().__init__()
@@ -1274,9 +1276,9 @@ class Unet(nn.Module):
# whether to use resnet or the (improved?) convnext blocks
if block_type == 'resnet':
block_klass = ResnetBlock
block_klass = partial(ResnetBlock, groups = block_resnet_groups)
elif block_type == 'convnext':
block_klass = ConvNextBlock
block_klass = partial(ConvNextBlock, mult = block_convnext_mult)
else:
raise ValueError(f'unimplemented block type {block_type}')

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.95',
version = '0.0.96',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',