mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 02:04:19 +01:00
let researchers do the hyperparameter search
This commit is contained in:
@@ -1072,6 +1072,8 @@ class Unet(nn.Module):
|
|||||||
cond_on_text_encodings = False,
|
cond_on_text_encodings = False,
|
||||||
max_text_len = 256,
|
max_text_len = 256,
|
||||||
cond_on_image_embeds = False,
|
cond_on_image_embeds = False,
|
||||||
|
init_dim = None,
|
||||||
|
init_conv_kernel_size = 7
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# save locals to take care of some hyperparameters for cascading DDPM
|
# save locals to take care of some hyperparameters for cascading DDPM
|
||||||
@@ -1089,9 +1091,10 @@ class Unet(nn.Module):
|
|||||||
self.channels = channels
|
self.channels = channels
|
||||||
|
|
||||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||||
init_dim = dim // 2
|
init_dim = default(init_dim, dim // 2)
|
||||||
|
|
||||||
self.init_conv = nn.Conv2d(init_channels, init_dim, 7, padding = 3)
|
assert (init_conv_kernel_size % 2) == 1
|
||||||
|
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
||||||
|
|
||||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||||
in_out = list(zip(dims[:-1], dims[1:]))
|
in_out = list(zip(dims[:-1], dims[1:]))
|
||||||
|
|||||||
Reference in New Issue
Block a user