mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
offer way to turn off initial cross embed convolutional module, for debugging upsampler artifacts
This commit is contained in:
@@ -1550,6 +1550,7 @@ class Unet(nn.Module):
|
|||||||
init_conv_kernel_size = 7,
|
init_conv_kernel_size = 7,
|
||||||
resnet_groups = 8,
|
resnet_groups = 8,
|
||||||
num_resnet_blocks = 2,
|
num_resnet_blocks = 2,
|
||||||
|
init_cross_embed = True,
|
||||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||||
cross_embed_downsample = False,
|
cross_embed_downsample = False,
|
||||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||||
@@ -1578,7 +1579,7 @@ class Unet(nn.Module):
|
|||||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||||
init_dim = default(init_dim, dim)
|
init_dim = default(init_dim, dim)
|
||||||
|
|
||||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
||||||
|
|
||||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||||
in_out = list(zip(dims[:-1], dims[1:]))
|
in_out = list(zip(dims[:-1], dims[1:]))
|
||||||
|
|||||||
@@ -225,6 +225,7 @@ class UnetConfig(BaseModel):
|
|||||||
self_attn: ListOrTuple(int)
|
self_attn: ListOrTuple(int)
|
||||||
attn_dim_head: int = 32
|
attn_dim_head: int = 32
|
||||||
attn_heads: int = 16
|
attn_heads: int = 16
|
||||||
|
init_cross_embed: bool = True
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
extra = "allow"
|
extra = "allow"
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.24.2'
|
__version__ = '0.24.3'
|
||||||
|
|||||||
Reference in New Issue
Block a user