Compare commits

...

3 Commits

2 changed files with 33 additions and 24 deletions

View File

@@ -137,23 +137,27 @@ def sigmoid_beta_schedule(timesteps):
# diffusion prior # diffusion prior
class RMSNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5): def __init__(self, dim, eps = 1e-5):
super().__init__() super().__init__()
self.eps = eps self.eps = eps
self.scale = dim ** 0.5 self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x): def forward(self, x):
squared_sum = (x ** 2).sum(dim = -1, keepdim = True) var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps) mean = torch.mean(x, dim = 1, keepdim = True)
return x * inv_norm * self.gamma * self.scale return (x - mean) / (var + self.eps).sqrt() * self.g
class ChanRMSNorm(RMSNorm):
def forward(self, x):
squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
class Residual(nn.Module): class Residual(nn.Module):
def __init__(self, fn): def __init__(self, fn):
@@ -249,10 +253,10 @@ def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
inner_dim = int(mult * dim) inner_dim = int(mult * dim)
return nn.Sequential( return nn.Sequential(
RMSNorm(dim), LayerNorm(dim),
nn.Linear(dim, inner_dim * 2, bias = False), nn.Linear(dim, inner_dim * 2, bias = False),
SwiGLU(), SwiGLU(),
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(), LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False) nn.Linear(inner_dim, dim, bias = False)
) )
@@ -275,7 +279,8 @@ class Attention(nn.Module):
inner_dim = dim_head * heads inner_dim = dim_head * heads
self.causal = causal self.causal = causal
self.norm = RMSNorm(dim) self.norm = LayerNorm(dim)
self.post_norm = LayerNorm(dim) # sandwich norm from CogView paper + Normformer
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head)) self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -331,7 +336,8 @@ class Attention(nn.Module):
out = einsum('b h i j, b j d -> b h i d', attn, v) out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)') out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out) out = self.to_out(out)
return self.post_norm(out)
class CausalTransformer(nn.Module): class CausalTransformer(nn.Module):
def __init__( def __init__(
@@ -344,7 +350,8 @@ class CausalTransformer(nn.Module):
ff_mult = 4, ff_mult = 4,
norm_out = False, norm_out = False,
attn_dropout = 0., attn_dropout = 0.,
ff_dropout = 0. ff_dropout = 0.,
final_proj = True
): ):
super().__init__() super().__init__()
self.rel_pos_bias = RelPosBias(heads = heads) self.rel_pos_bias = RelPosBias(heads = heads)
@@ -356,7 +363,8 @@ class CausalTransformer(nn.Module):
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
])) ]))
self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
def forward( def forward(
self, self,
@@ -371,7 +379,8 @@ class CausalTransformer(nn.Module):
x = attn(x, mask = mask, attn_bias = attn_bias) + x x = attn(x, mask = mask, attn_bias = attn_bias) + x
x = ff(x) + x x = ff(x) + x
return self.norm(x) out = self.norm(x)
return self.project_out(out)
class DiffusionPriorNetwork(nn.Module): class DiffusionPriorNetwork(nn.Module):
def __init__( def __init__(
@@ -720,7 +729,7 @@ class ConvNextBlock(nn.Module):
inner_dim = int(dim_out * mult) inner_dim = int(dim_out * mult)
self.net = nn.Sequential( self.net = nn.Sequential(
ChanRMSNorm(dim) if norm else nn.Identity(), ChanLayerNorm(dim) if norm else nn.Identity(),
nn.Conv2d(dim, inner_dim, 3, padding = 1), nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(), nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1) nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
@@ -756,8 +765,8 @@ class CrossAttention(nn.Module):
context_dim = default(context_dim, dim) context_dim = default(context_dim, dim)
self.norm = RMSNorm(dim) self.norm = LayerNorm(dim)
self.norm_context = RMSNorm(context_dim) self.norm_context = LayerNorm(context_dim)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head)) self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -1075,14 +1084,14 @@ class LowresConditioner(nn.Module):
if self.training and self.downsample_first and exists(downsample_image_size): if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode) cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
if self.training: if self.training:
# when training, blur the low resolution conditional image # when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma) blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size) blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2)) cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
return cond_fmap return cond_fmap
class Decoder(nn.Module): class Decoder(nn.Module):

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.32', version = '0.0.35',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',