diff --git a/README.md b/README.md index 7adfa7e..5826f5a 100644 --- a/README.md +++ b/README.md @@ -866,14 +866,6 @@ Once built, images will be saved to the same directory the command is invoked } ``` -```bibtex -@inproceedings{Liu2022ACF, - title = {A ConvNet for the 2020s}, - author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie}, - year = {2022} -} -``` - ```bibtex @article{shen2019efficient, author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li}, diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 35df9d8..8b9bfda 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -999,68 +999,6 @@ class ResnetBlock(nn.Module): h = self.block2(h) return h + self.res_conv(x) -class ConvNextBlock(nn.Module): - """ https://arxiv.org/abs/2201.03545 """ - - def __init__( - self, - dim, - dim_out, - *, - cond_dim = None, - time_cond_dim = None, - mult = 2 - ): - super().__init__() - need_projection = dim != dim_out - - self.cross_attn = None - - if exists(cond_dim): - self.cross_attn = EinopsToAndFrom( - 'b c h w', - 'b (h w) c', - CrossAttention( - dim = dim, - context_dim = cond_dim - ) - ) - - self.time_mlp = None - - if exists(time_cond_dim): - self.time_mlp = nn.Sequential( - nn.GELU(), - nn.Linear(time_cond_dim, dim) - ) - - self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim) - - inner_dim = int(dim_out * mult) - self.net = nn.Sequential( - ChanLayerNorm(dim), - nn.Conv2d(dim, inner_dim, 3, padding = 1), - nn.GELU(), - nn.Conv2d(inner_dim, dim_out, 3, padding = 1) - ) - - self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity() - - def forward(self, x, cond = None, time = None): - h = self.ds_conv(x) - - if exists(time) and exists(self.time_mlp): - t = self.time_mlp(time) - h = rearrange(t, 'b c -> b c 1 1') + h - - if exists(self.cross_attn): - assert exists(cond) - h = self.cross_attn(h, context = cond) + h - - h = self.net(h) - - return h + self.res_conv(x) - class CrossAttention(nn.Module): def __init__( self, @@ -1200,7 +1138,6 @@ class Unet(nn.Module): init_conv_kernel_size = 7, block_type = 'resnet', block_resnet_groups = 8, - block_convnext_mult = 2, **kwargs ): super().__init__() @@ -1276,14 +1213,9 @@ class Unet(nn.Module): attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head) - # whether to use resnet or the (improved?) convnext blocks + # resnet block klass - if block_type == 'resnet': - block_klass = partial(ResnetBlock, groups = block_resnet_groups) - elif block_type == 'convnext': - block_klass = partial(ConvNextBlock, mult = block_convnext_mult) - else: - raise ValueError(f'unimplemented block type {block_type}') + block_klass = partial(ResnetBlock, groups = block_resnet_groups) # layers diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py index 7c97e6b..4d08389 100644 --- a/dalle2_pytorch/vqgan_vae.py +++ b/dalle2_pytorch/vqgan_vae.py @@ -331,112 +331,6 @@ class ResBlock(nn.Module): def forward(self, x): return self.net(x) + x -# convnext enc dec - -class ChanLayerNorm(nn.Module): - def __init__(self, dim, eps = 1e-5): - super().__init__() - self.eps = eps - self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) - - def forward(self, x): - var = torch.var(x, dim = 1, unbiased = False, keepdim = True) - mean = torch.mean(x, dim = 1, keepdim = True) - return (x - mean) / (var + self.eps).sqrt() * self.g - -class ConvNext(nn.Module): - def __init__(self, dim, mult = 4, kernel_size = 3, ds_kernel_size = 7): - super().__init__() - inner_dim = int(dim * mult) - self.net = nn.Sequential( - nn.Conv2d(dim, dim, ds_kernel_size, padding = ds_kernel_size // 2, groups = dim), - ChanLayerNorm(dim), - nn.Conv2d(dim, inner_dim, kernel_size, padding = kernel_size // 2), - nn.GELU(), - nn.Conv2d(inner_dim, dim, kernel_size, padding = kernel_size // 2) - ) - - def forward(self, x): - return self.net(x) + x - -class ConvNextEncDec(nn.Module): - def __init__( - self, - dim, - *, - channels = 3, - layers = 4, - layer_mults = None, - num_blocks = 1, - first_conv_kernel_size = 5, - use_attn = True, - attn_dim_head = 64, - attn_heads = 8, - attn_dropout = 0., - ): - super().__init__() - - self.layers = layers - - self.encoders = MList([]) - self.decoders = MList([]) - - layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers)))) - assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers' - - layer_dims = [dim * mult for mult in layer_mults] - dims = (dim, *layer_dims) - - self.encoded_dim = dims[-1] - - dim_pairs = zip(dims[:-1], dims[1:]) - - append = lambda arr, t: arr.append(t) - prepend = lambda arr, t: arr.insert(0, t) - - if not isinstance(num_blocks, tuple): - num_blocks = (*((0,) * (layers - 1)), num_blocks) - - if not isinstance(use_attn, tuple): - use_attn = (*((False,) * (layers - 1)), use_attn) - - assert len(num_blocks) == layers, 'number of blocks config must be equal to number of layers' - assert len(use_attn) == layers - - for layer_index, (dim_in, dim_out), layer_num_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_blocks, use_attn): - append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu())) - prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu())) - - if layer_use_attn: - prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout)) - - for _ in range(layer_num_blocks): - append(self.encoders, ConvNext(dim_out)) - prepend(self.decoders, ConvNext(dim_out)) - - if layer_use_attn: - append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout)) - - prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2)) - append(self.decoders, nn.Conv2d(dim, channels, 1)) - - def get_encoded_fmap_size(self, image_size): - return image_size // (2 ** self.layers) - - @property - def last_dec_layer(self): - return self.decoders[-1].weight - - def encode(self, x): - for enc in self.encoders: - x = enc(x) - return x - - def decode(self, x): - for dec in self.decoders: - x = dec(x) - return x - # vqgan attention layer class VQGanAttention(nn.Module): @@ -682,8 +576,6 @@ class VQGanVAE(nn.Module): enc_dec_klass = ResnetEncDec elif vae_type == 'vit': enc_dec_klass = ViTEncDec - elif vae_type == 'convnext': - enc_dec_klass = ConvNextEncDec else: raise ValueError(f'{vae_type} not valid') diff --git a/setup.py b/setup.py index dc5c744..7c79ae4 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.100', + version = '0.0.101', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',