let researchers do the hyperparameter search

This commit is contained in:
Phil Wang
2022-05-01 08:46:21 -07:00
parent 67fcab1122
commit 5e421bd5bb
2 changed files with 6 additions and 3 deletions

View File

@@ -1072,6 +1072,8 @@ class Unet(nn.Module):
     cond_on_text_encodings = False,
     max_text_len = 256,
     cond_on_image_embeds = False,
+    init_dim = None,
+    init_conv_kernel_size = 7
 ):
     super().__init__()
     # save locals to take care of some hyperparameters for cascading DDPM
@@ -1089,9 +1091,10 @@ class Unet(nn.Module):
     self.channels = channels
     init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
-    init_dim = dim // 2
-    self.init_conv = nn.Conv2d(init_channels, init_dim, 7, padding = 3)
+    init_dim = default(init_dim, dim // 2)
+    assert (init_conv_kernel_size % 2) == 1
+    self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
     dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
     in_out = list(zip(dims[:-1], dims[1:]))

View File

@@ -10,7 +10,7 @@ setup(
             'dream = dalle2_pytorch.cli:dream'
         ],
     },
-    version = '0.0.84',
+    version = '0.0.85',
     license='MIT',
     description = 'DALL-E 2',
     author = 'Phil Wang',