diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 2cb9811..dc35304 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1196,6 +1196,8 @@ class Unet(nn.Module): init_dim = None, init_conv_kernel_size = 7, block_type = 'resnet', + block_resnet_groups = 8, + block_convnext_mult = 2, **kwargs ): super().__init__() @@ -1274,9 +1276,9 @@ class Unet(nn.Module): # whether to use resnet or the (improved?) convnext blocks if block_type == 'resnet': - block_klass = ResnetBlock + block_klass = partial(ResnetBlock, groups = block_resnet_groups) elif block_type == 'convnext': - block_klass = ConvNextBlock + block_klass = partial(ConvNextBlock, mult = block_convnext_mult) else: raise ValueError(f'unimplemented block type {block_type}') diff --git a/setup.py b/setup.py index 33a9629..5ee48dc 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.95', + version = '0.0.96', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',