mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
make sure some hyperparameters for unet block is configurable
This commit is contained in:
@@ -1196,6 +1196,8 @@ class Unet(nn.Module):
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
block_type = 'resnet',
|
||||
block_resnet_groups = 8,
|
||||
block_convnext_mult = 2,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1274,9 +1276,9 @@ class Unet(nn.Module):
|
||||
# whether to use resnet or the (improved?) convnext blocks
|
||||
|
||||
if block_type == 'resnet':
|
||||
block_klass = ResnetBlock
|
||||
block_klass = partial(ResnetBlock, groups = block_resnet_groups)
|
||||
elif block_type == 'convnext':
|
||||
block_klass = ConvNextBlock
|
||||
block_klass = partial(ConvNextBlock, mult = block_convnext_mult)
|
||||
else:
|
||||
raise ValueError(f'unimplemented block type {block_type}')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user