Compare commits

..

2 Commits

4 changed files with 34 additions and 17 deletions

View File

@@ -1107,13 +1107,20 @@ class Block(nn.Module):
groups = 8
):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding = 1),
nn.GroupNorm(groups, dim_out),
nn.SiLU()
)
def forward(self, x):
return self.block(x)
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.project(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
def __init__(
@@ -1132,7 +1139,7 @@ class ResnetBlock(nn.Module):
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_cond_dim, dim_out)
nn.Linear(time_cond_dim, dim_out * 2)
)
self.cross_attn = None
@@ -1152,11 +1159,14 @@ class ResnetBlock(nn.Module):
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, cond = None, time_emb = None):
h = self.block1(x)
scale_shift = None
if exists(self.time_mlp) and exists(time_emb):
time_emb = self.time_mlp(time_emb)
h = rearrange(time_emb, 'b c -> b c 1 1') + h
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
scale_shift = time_emb.chunk(2, dim = 1)
h = self.block1(x, scale_shift = scale_shift)
if exists(self.cross_attn):
assert exists(cond)

View File

@@ -12,6 +12,7 @@ def get_optimizer(
betas = (0.9, 0.999),
eps = 1e-8,
filter_by_requires_grad = False,
group_wd_params = True,
**kwargs
):
if filter_by_requires_grad:
@@ -21,11 +22,13 @@ def get_optimizer(
return Adam(params, lr = lr, betas = betas, eps = eps)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
param_groups = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
if group_wd_params:
wd_params, no_wd_params = separate_weight_decayable_params(params)
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
params = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -254,6 +254,7 @@ class DiffusionPriorTrainer(nn.Module):
eps = 1e-6,
max_grad_norm = None,
amp = False,
group_wd_params = True,
**kwargs
):
super().__init__()
@@ -279,6 +280,7 @@ class DiffusionPriorTrainer(nn.Module):
lr = lr,
wd = wd,
eps = eps,
group_wd_params = group_wd_params,
**kwargs
)
@@ -410,6 +412,7 @@ class DecoderTrainer(nn.Module):
eps = 1e-8,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
**kwargs
):
super().__init__()
@@ -435,6 +438,7 @@ class DecoderTrainer(nn.Module):
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.5.0',
version = '0.5.2',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',