mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
let researchers do the hyperparameter search
This commit is contained in:
@@ -1072,6 +1072,8 @@ class Unet(nn.Module):
|
||||
cond_on_text_encodings = False,
|
||||
max_text_len = 256,
|
||||
cond_on_image_embeds = False,
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7
|
||||
):
|
||||
super().__init__()
|
||||
# save locals to take care of some hyperparameters for cascading DDPM
|
||||
@@ -1089,9 +1091,10 @@ class Unet(nn.Module):
|
||||
self.channels = channels
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = dim // 2
|
||||
init_dim = default(init_dim, dim // 2)
|
||||
|
||||
self.init_conv = nn.Conv2d(init_channels, init_dim, 7, padding = 3)
|
||||
assert (init_conv_kernel_size % 2) == 1
|
||||
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
||||
|
||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
|
||||
Reference in New Issue
Block a user