From 44e09d5a4d3a0fb6d9dbebb9f0ae6bbfb5eba606 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 14 Aug 2022 11:34:45 -0700
Subject: [PATCH] add weight standardization behind feature flag, which may
 potentially work well with group norm

---
 README.md                        | 10 +++++++++
 dalle2_pytorch/dalle2_pytorch.py | 35 +++++++++++++++++++++++++++-----
 dalle2_pytorch/version.py        |  2 +-
 3 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index cd3fb3d..6741435 100644
--- a/README.md
+++ b/README.md
@@ -1264,4 +1264,14 @@ For detailed information on training the diffusion prior, please refer to the [d
 }
 ```
 
+```bibtex
+@article{Qiao2019WeightS,
+    title   = {Weight Standardization},
+    author  = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
+    journal = {ArXiv},
+    year    = {2019},
+    volume  = {abs/1903.10520}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index f7c3b1c..ce7339e 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1451,6 +1451,26 @@ def Downsample(dim, *, dim_out = None):
     dim_out = default(dim_out, dim)
     return nn.Conv2d(dim, dim_out, 4, 2, 1)
 
+class WeightStandardizedConv2d(nn.Conv2d):
+    """
+    https://arxiv.org/abs/1903.10520
+    weight standardization purportedly works synergistically with group normalization
+    """
+    def forward(self, x):
+        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+
+        weight = self.weight
+        flattened_weights = rearrange(weight, 'o ... -> o (...)')
+
+        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
+
+        var = torch.var(flattened_weights, dim = -1, unbiased = False)
+        var = rearrange(var, 'o -> o 1 1 1')
+
+        weight = (weight - mean) * (var + eps).rsqrt()
+
+        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -1469,10 +1489,13 @@ class Block(nn.Module):
         self,
         dim,
         dim_out,
-        groups = 8
+        groups = 8,
+        weight_standardization = False
     ):
         super().__init__()
-        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
+        conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
+
+        self.project = conv_klass(dim, dim_out, 3, padding = 1)
         self.norm = nn.GroupNorm(groups, dim_out)
         self.act = nn.SiLU()
 
@@ -1496,6 +1519,7 @@ class ResnetBlock(nn.Module):
         cond_dim = None,
         time_cond_dim = None,
         groups = 8,
+        weight_standardization = False,
         cosine_sim_cross_attn = False
     ):
         super().__init__()
@@ -1521,8 +1545,8 @@ class ResnetBlock(nn.Module):
             )
         )
 
-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
+        self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
     def forward(self, x, time_emb = None, cond = None):
@@ -1747,6 +1771,7 @@ class Unet(nn.Module):
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
+        resnet_weight_standardization = False,
         num_resnet_blocks = 2,
         init_cross_embed = True,
         init_cross_embed_kernel_sizes = (3, 7, 15),
@@ -1894,7 +1919,7 @@ class Unet(nn.Module):
 
         # prepare resnet klass
 
-        resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
+        resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
 
         # give memory efficient unet an initial resnet block
 
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index f3df7f0..0e1a38d 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.6.5'
+__version__ = '1.7.0'
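
A minimal usage sketch (not part of the patch) of the layer and flag introduced above, assuming the patch has been applied and `torch`/`einops` are installed; the import path simply mirrors where `WeightStandardizedConv2d` is defined in the diff:

```python
# Sketch only: exercises the WeightStandardizedConv2d added by this patch.
import torch
from dalle2_pytorch.dalle2_pytorch import WeightStandardizedConv2d

# constructor signature is inherited unchanged from nn.Conv2d
conv = WeightStandardizedConv2d(3, 16, 3, padding = 1)

x = torch.randn(2, 3, 64, 64)
out = conv(x)
print(out.shape)  # torch.Size([2, 16, 64, 64])

# In the Unet itself, the same behavior is enabled via the feature flag,
# which is threaded down as Unet(resnet_weight_standardization = True)
# -> ResnetBlock(weight_standardization = True) -> Block -> WeightStandardizedConv2d.
```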