let researchers do the hyperparameter search

This commit is contained in:
Phil Wang
2022-05-01 08:46:21 -07:00
parent 67fcab1122
commit 5e421bd5bb
2 changed files with 6 additions and 3 deletions

View File

@@ -1072,6 +1072,8 @@ class Unet(nn.Module):
cond_on_text_encodings = False,
max_text_len = 256,
cond_on_image_embeds = False,
init_dim = None,
init_conv_kernel_size = 7
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
@@ -1089,9 +1091,10 @@ class Unet(nn.Module):
self.channels = channels
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = dim // 2
init_dim = default(init_dim, dim // 2)
self.init_conv = nn.Conv2d(init_channels, init_dim, 7, padding = 3)
assert (init_conv_kernel_size % 2) == 1
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.84',
version = '0.0.85',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',