mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
make sure some hyperparameters for unet block is configurable
This commit is contained in:
@@ -1196,6 +1196,8 @@ class Unet(nn.Module):
|
|||||||
init_dim = None,
|
init_dim = None,
|
||||||
init_conv_kernel_size = 7,
|
init_conv_kernel_size = 7,
|
||||||
block_type = 'resnet',
|
block_type = 'resnet',
|
||||||
|
block_resnet_groups = 8,
|
||||||
|
block_convnext_mult = 2,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1274,9 +1276,9 @@ class Unet(nn.Module):
|
|||||||
# whether to use resnet or the (improved?) convnext blocks
|
# whether to use resnet or the (improved?) convnext blocks
|
||||||
|
|
||||||
if block_type == 'resnet':
|
if block_type == 'resnet':
|
||||||
block_klass = ResnetBlock
|
block_klass = partial(ResnetBlock, groups = block_resnet_groups)
|
||||||
elif block_type == 'convnext':
|
elif block_type == 'convnext':
|
||||||
block_klass = ConvNextBlock
|
block_klass = partial(ConvNextBlock, mult = block_convnext_mult)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'unimplemented block type {block_type}')
|
raise ValueError(f'unimplemented block type {block_type}')
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user