mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
add weight standardization behind a feature flag, which may work well with group norm
This commit is contained in:
10
README.md
10
README.md
@@ -1264,4 +1264,14 @@ For detailed information on training the diffusion prior, please refer to the [d
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{Qiao2019WeightS,
|
||||||
|
title = {Weight Standardization},
|
||||||
|
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
|
||||||
|
journal = {ArXiv},
|
||||||
|
year = {2019},
|
||||||
|
volume = {abs/1903.10520}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||||
|
|||||||
@@ -1451,6 +1451,26 @@ def Downsample(dim, *, dim_out = None):
|
|||||||
dim_out = default(dim_out, dim)
|
dim_out = default(dim_out, dim)
|
||||||
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
||||||
|
|
||||||
|
class WeightStandardizedConv2d(nn.Conv2d):
    """
    2d convolution that standardizes its kernel on every forward pass
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """

    def forward(self, x):
        # float32 inputs use 1e-5, every other input dtype uses 1e-3
        eps = 1e-3 if x.dtype != torch.float32 else 1e-5

        kernel = self.weight

        # statistics are taken per output filter, over all remaining kernel dimensions
        center = kernel.mean(dim = (1, 2, 3), keepdim = True)

        # biased (population) variance, reshaped to broadcast against the kernel
        spread = kernel.flatten(1).var(dim = -1, unbiased = False)
        spread = spread[:, None, None, None]

        standardized = (kernel - center) * torch.rsqrt(spread + eps)

        # only the kernel is standardized, the stored parameter and bias are left untouched
        return F.conv2d(
            x,
            standardized,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups
        )
|
||||||
|
|
||||||
class SinusoidalPosEmb(nn.Module):
|
class SinusoidalPosEmb(nn.Module):
|
||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1469,10 +1489,13 @@ class Block(nn.Module):
|
|||||||
self,
|
self,
|
||||||
dim,
|
dim,
|
||||||
dim_out,
|
dim_out,
|
||||||
groups = 8
|
groups = 8,
|
||||||
|
weight_standardization = False
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
|
conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
|
||||||
|
|
||||||
|
self.project = conv_klass(dim, dim_out, 3, padding = 1)
|
||||||
self.norm = nn.GroupNorm(groups, dim_out)
|
self.norm = nn.GroupNorm(groups, dim_out)
|
||||||
self.act = nn.SiLU()
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
@@ -1496,6 +1519,7 @@ class ResnetBlock(nn.Module):
|
|||||||
cond_dim = None,
|
cond_dim = None,
|
||||||
time_cond_dim = None,
|
time_cond_dim = None,
|
||||||
groups = 8,
|
groups = 8,
|
||||||
|
weight_standardization = False,
|
||||||
cosine_sim_cross_attn = False
|
cosine_sim_cross_attn = False
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1521,8 +1545,8 @@ class ResnetBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.block1 = Block(dim, dim_out, groups = groups)
|
self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
|
||||||
self.block2 = Block(dim_out, dim_out, groups = groups)
|
self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
|
||||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, time_emb = None, cond = None):
|
def forward(self, x, time_emb = None, cond = None):
|
||||||
@@ -1747,6 +1771,7 @@ class Unet(nn.Module):
|
|||||||
init_dim = None,
|
init_dim = None,
|
||||||
init_conv_kernel_size = 7,
|
init_conv_kernel_size = 7,
|
||||||
resnet_groups = 8,
|
resnet_groups = 8,
|
||||||
|
resnet_weight_standardization = False,
|
||||||
num_resnet_blocks = 2,
|
num_resnet_blocks = 2,
|
||||||
init_cross_embed = True,
|
init_cross_embed = True,
|
||||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||||
@@ -1894,7 +1919,7 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
# prepare resnet klass
|
# prepare resnet klass
|
||||||
|
|
||||||
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
|
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
|
||||||
|
|
||||||
# give memory efficient unet an initial resnet block
|
# give memory efficient unet an initial resnet block
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '1.6.5'
|
__version__ = '1.7.0'
|
||||||
|
|||||||
Reference in New Issue
Block a user