diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 9ad5025..fbfeb88 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1072,6 +1072,8 @@ class Unet(nn.Module): cond_on_text_encodings = False, max_text_len = 256, cond_on_image_embeds = False, + init_dim = None, + init_conv_kernel_size = 7 ): super().__init__() # save locals to take care of some hyperparameters for cascading DDPM @@ -1089,9 +1091,10 @@ class Unet(nn.Module): self.channels = channels init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis - init_dim = dim // 2 + init_dim = default(init_dim, dim // 2) - self.init_conv = nn.Conv2d(init_channels, init_dim, 7, padding = 3) + assert (init_conv_kernel_size % 2) == 1 + self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) diff --git a/setup.py b/setup.py index e3ed51a..82ae401 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.84', + version = '0.0.85', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',