Compare commits


12 Commits
1.0.1 ... 1.4.0

Author SHA1 Message Date
Phil Wang       748c7fe7af  allow for cosine sim cross attention, modify linear attention in attempt to resolve issue on fp16   2022-07-29 11:12:18 -07:00
Phil Wang       80046334ad  make sure entire readme runs without errors   2022-07-28 10:17:43 -07:00
Phil Wang       36fb46a95e  fix readme and a small bug in DALLE2 class   2022-07-28 08:33:51 -07:00
Phil Wang       07abfcf45b  rescale values in linear attention to mitigate overflows in fp16 setting   2022-07-27 12:27:38 -07:00
Phil Wang       2e35a9967d  product management   2022-07-26 11:10:16 -07:00
Phil Wang       406e75043f  add upsample combiner feature for the unets   2022-07-26 10:46:04 -07:00
Phil Wang       9646dfc0e6  fix path_or_state bug   2022-07-26 09:47:54 -07:00
Phil Wang       62043acb2f  fix repaint   2022-07-24 15:29:06 -07:00
Phil Wang       417ff808e6  1.0.3   2022-07-22 13:16:57 -07:00
Aidan Dempster  f3d7e226ba  Changed types to be generic instead of functions (#215)   2022-07-22 13:16:29 -07:00
                            This allows pylance to do proper type hinting and makes developing extensions to the package much easier
Phil Wang       48a1302428  1.0.2   2022-07-20 23:01:51 -07:00
Aidan Dempster  ccaa46b81b  Re-introduced change that was accidentally rolled back (#212)   2022-07-20 23:01:19 -07:00
6 changed files with 220 additions and 110 deletions

View File

@@ -371,6 +371,7 @@ loss.backward()
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8),
@@ -395,7 +396,7 @@ decoder = Decoder(
).cuda()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
@@ -860,25 +861,23 @@ unet1 = Unet(
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
dim_mults=(1, 2, 4, 8),
cond_on_text_encodings = True,
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_text_encodings = True
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 1000,
condition_on_text_encodings = True
timesteps = 1000
).cuda()
decoder_trainer = DecoderTrainer(
@@ -903,8 +902,8 @@ for unet_number in (1, 2):
# after much training
# you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
mock_image_embed = torch.randn(32, 512).cuda()
images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)
```
### Diffusion Prior Training
@@ -1112,7 +1111,8 @@ For detailed information on training the diffusion prior, please refer to the [d
- [x] allow for unet to be able to condition non-cross attention style as well
- [x] speed up inference, read up on papers (ddim)
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
## Citations
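
As an editorial aside on the README changes above: the two new Unet-level switches surfaced by this range (cosine-sim cross attention from 748c7fe7af, the upsample feature-map combiner from 406e75043f) would be enabled roughly as follows. This is a sketch, not a snippet from the README; all other arguments mirror the README examples above.

```python
unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    text_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    cond_on_text_encodings = True,
    cosine_sim_cross_attn = True,   # new: l2-normalized q/k in cross attention, for fp16 stability
    combine_upsample_fmaps = True   # new: concatenate all upsample feature maps before the final resnet block, as in unet squared
).cuda()
```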

View File

@@ -516,6 +516,17 @@ class NoiseScheduler(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
shape = x_from.shape
noise = default(noise, lambda: torch.randn_like(x_from))
alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)
sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)
alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)
sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)
return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
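
The new q_sample_from_to helper re-noises a sample from timestep from_t up to a noisier timestep to_t. Restated from the code above, with ᾱ = sqrt_alphas_cumprod and σ = sqrt_one_minus_alphas_cumprod evaluated at the respective timesteps:

$$x_{\mathrm{to}} \;=\; x_{\mathrm{from}}\,\frac{\bar\alpha_{\mathrm{to}}}{\bar\alpha_{\mathrm{from}}} \;+\; \epsilon\,\frac{\sigma_{\mathrm{to}}\,\bar\alpha_{\mathrm{from}} - \sigma_{\mathrm{from}}\,\bar\alpha_{\mathrm{to}}}{\bar\alpha_{\mathrm{from}}}, \qquad \epsilon \sim \mathcal{N}(0, I)$$

The RePaint-style sampling loops added further down call it as q_sample_from_to(img, times - 1, times) to push a partially denoised image back up one noise level before resampling.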
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
@@ -1346,7 +1357,8 @@ class ResnetBlock(nn.Module):
*,
cond_dim = None,
time_cond_dim = None,
groups = 8
groups = 8,
cosine_sim_cross_attn = False
):
super().__init__()
@@ -1366,7 +1378,8 @@ class ResnetBlock(nn.Module):
'b (h w) c',
CrossAttention(
dim = dim_out,
context_dim = cond_dim
context_dim = cond_dim,
cosine_sim = cosine_sim_cross_attn
)
)
@@ -1401,11 +1414,12 @@ class CrossAttention(nn.Module):
heads = 8,
dropout = 0.,
norm_context = False,
pb_relax_alpha = 32 ** 2
cosine_sim = False,
cosine_sim_scale = 16
):
super().__init__()
self.pb_relax_alpha = pb_relax_alpha
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
self.cosine_sim = cosine_sim
self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
self.heads = heads
inner_dim = dim_head * heads
@@ -1441,7 +1455,10 @@ class CrossAttention(nn.Module):
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q = q * self.scale
if self.cosine_sim:
q, k = map(l2norm, (q, k))
q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
sim = einsum('b h i d, b h j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max
@@ -1451,9 +1468,6 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
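
In short, the previous PB-relax style rescaling (the amax subtraction followed by multiplying back by pb_relax_alpha) is removed because cosine-sim attention bounds the logits by construction. A minimal standalone sketch of the idea (the helper name and shapes are illustrative; cosine_sim_scale = 16 matches the default in the diff):

```python
import torch
import torch.nn.functional as F

def cosine_sim_scores(q, k, scale = 16):
    # l2-normalize queries and keys, then split the scale evenly between them,
    # so q @ k^T = scale * cosine_similarity, bounded in [-scale, scale] - safe in fp16
    q, k = map(lambda t: F.normalize(t, dim = -1) * (scale ** 0.5), (q, k))
    return torch.einsum('... i d, ... j d -> ... i j', q, k)

scores = cosine_sim_scores(torch.randn(1, 8, 32, 64), torch.randn(1, 8, 77, 64))  # (1, 8, 32, 77)
```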
@@ -1483,6 +1497,7 @@ class LinearAttention(nn.Module):
def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:]
seq_len = x * y
fmap = self.norm(fmap)
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
@@ -1492,6 +1507,9 @@ class LinearAttention(nn.Module):
k = k.softmax(dim = -2)
q = q * self.scale
v = l2norm(v)
k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))
context = einsum('b n d, b n e -> b d e', k, v)
out = einsum('b n d, b d e -> b n e', q, context)
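
The 1/sqrt(seq_len) rescaling above addresses the fp16 overflow mentioned in commits 07abfcf45b and 748c7fe7af: the context einsum sums seq_len outer products, so without rescaling its entries grow with feature-map resolution. A rough standalone sketch (function name and shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def linear_attention_context(k, v):                # k, v: (batch, seq_len, dim_head)
    seq_len = k.shape[-2]
    k = k.softmax(dim = -2)                         # attention weights over the sequence
    v = F.normalize(v, dim = -1)                    # l2-normalized values, as in the diff
    k, v = map(lambda t: t / (seq_len ** 0.5), (k, v))
    # each of the seq_len outer products is now scaled by 1/seq_len, so the summed
    # context matrix stays O(1) regardless of feature-map size - important in fp16
    return torch.einsum('b n d, b n e -> b d e', k, v)

context = linear_attention_context(torch.randn(2, 4096, 32), torch.randn(2, 4096, 32))  # (2, 32, 32)
```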
@@ -1527,6 +1545,38 @@ class CrossEmbedLayer(nn.Module):
fmaps = tuple(map(lambda conv: conv(x), self.convs))
return torch.cat(fmaps, dim = 1)
class UpsampleCombiner(nn.Module):
def __init__(
self,
dim,
*,
enabled = False,
dim_ins = tuple(),
dim_outs = tuple()
):
super().__init__()
assert len(dim_ins) == len(dim_outs)
self.enabled = enabled
if not self.enabled:
self.dim_out = dim
return
self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
def forward(self, x, fmaps = None):
target_size = x.shape[-1]
fmaps = default(fmaps, tuple())
if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
return x
fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
return torch.cat((x, *outs), dim = 1)
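
A quick shape walk-through of the new combiner, assuming the class as defined above (the channel counts are made up for illustration):

```python
combiner = UpsampleCombiner(
    dim = 128,                       # channels of the top-level feature map
    enabled = True,
    dim_ins = (512, 256, 128),       # channels coming out of each upsample stage
    dim_outs = (128, 128, 128)       # each map projected to `dim` channels, matching the Unet wiring below
)
# combiner.dim_out == 128 + 3 * 128 == 512
# forward resizes every stored upsample feature map to the final resolution,
# projects it with a Block, and concatenates everything onto x along the channel dim
```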
class Unet(nn.Module):
def __init__(
self,
@@ -1547,6 +1597,7 @@ class Unet(nn.Module):
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
sparse_attn = False,
cosine_sim_cross_attn = False,
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
max_text_len = 256,
@@ -1564,6 +1615,7 @@ class Unet(nn.Module):
scale_skip_connection = False,
pixel_shuffle_upsample = True,
final_conv_kernel_size = 1,
combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
**kwargs
):
super().__init__()
@@ -1689,9 +1741,13 @@ class Unet(nn.Module):
upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
# prepare resnet klass
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
# give memory efficient unet an initial resnet block
self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
self.init_resnet_block = resnet_block(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
# layers
@@ -1699,7 +1755,8 @@ class Unet(nn.Module):
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
skip_connect_dims = [] # keeping track of skip connection dimensions
skip_connect_dims = [] # keeping track of skip connection dimensions
upsample_combiner_dims = [] # keeping track of dimensions for final upsample feature map combiner
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
is_first = ind == 0
@@ -1717,17 +1774,17 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
resnet_block(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([resnet_block(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention,
downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
]))
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_block1 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = create_self_attn(mid_dim)
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_block2 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
is_last = ind >= (len(in_out) - 1)
@@ -1741,14 +1798,27 @@ class Unet(nn.Module):
elif sparse_attn:
attention = Residual(LinearAttention(dim_out, **attn_kwargs))
upsample_combiner_dims.append(dim_out)
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention,
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
]))
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
# whether to combine outputs from all upsample blocks for final resnet block
self.upsample_combiner = UpsampleCombiner(
dim = dim,
enabled = combine_upsample_fmaps,
dim_ins = upsample_combiner_dims,
dim_outs = (dim,) * len(upsample_combiner_dims)
)
# a final resnet block
self.final_resnet_block = resnet_block(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
out_dim_in = dim + (channels if lowres_cond else 0)
@@ -1772,7 +1842,7 @@ class Unet(nn.Module):
channels == self.channels and \
cond_on_image_embeds == self.cond_on_image_embeds and \
cond_on_text_encodings == self.cond_on_text_encodings and \
cond_on_lowres_noise == self.cond_on_lowres_noise and \
lowres_noise_cond == self.lowres_noise_cond and \
channels_out == self.channels_out:
return self
@@ -1942,7 +2012,8 @@ class Unet(nn.Module):
# go through the layers of the unet, down and up
hiddens = []
down_hiddens = []
up_hiddens = []
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
if exists(pre_downsample):
@@ -1952,10 +2023,10 @@ class Unet(nn.Module):
for resnet_block in resnet_blocks:
x = resnet_block(x, t, c)
hiddens.append(x)
down_hiddens.append(x.contiguous())
x = attn(x)
hiddens.append(x.contiguous())
down_hiddens.append(x.contiguous())
if exists(post_downsample):
x = post_downsample(x)
@@ -1967,7 +2038,7 @@ class Unet(nn.Module):
x = self.mid_block2(x, t, mid_c)
connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
for init_block, resnet_blocks, attn, upsample in self.ups:
x = connect_skip(x)
@@ -1978,8 +2049,12 @@ class Unet(nn.Module):
x = resnet_block(x, t, c)
x = attn(x)
up_hiddens.append(x.contiguous())
x = upsample(x)
x = self.upsample_combiner(x, up_hiddens)
x = torch.cat((x, r), dim = 1)
x = self.final_resnet_block(x, t)
@@ -2432,14 +2507,18 @@ class Decoder(nn.Module):
is_latent_diffusion = False,
lowres_noise_level = None,
inpaint_image = None,
inpaint_mask = None
inpaint_mask = None,
inpaint_resample_times = 5
):
device = self.device
b = shape[0]
img = torch.randn(shape, device = device)
if exists(inpaint_image):
is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1
if is_inpaint:
inpaint_image = self.normalize_img(inpaint_image)
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
@@ -2449,31 +2528,40 @@ class Decoder(nn.Module):
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long)
for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
is_last_timestep = time == 0
if exists(inpaint_image):
# following the repaint paper
# https://arxiv.org/abs/2201.09865
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
for r in reversed(range(0, resample_times)):
is_last_resample_step = r == 0
img = self.p_sample(
unet,
img,
times,
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level,
predict_x_start = predict_x_start,
noise_scheduler = noise_scheduler,
learned_variance = learned_variance,
clip_denoised = clip_denoised
)
times = torch.full((b,), time, device = device, dtype = torch.long)
if exists(inpaint_image):
if is_inpaint:
# following the repaint paper
# https://arxiv.org/abs/2201.09865
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
img = self.p_sample(
unet,
img,
times,
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level,
predict_x_start = predict_x_start,
noise_scheduler = noise_scheduler,
learned_variance = learned_variance,
clip_denoised = clip_denoised
)
if is_inpaint and not (is_last_timestep or is_last_resample_step):
# in repaint, you renoise and resample up to 10 times every step
img = noise_scheduler.q_sample_from_to(img, times - 1, times)
if is_inpaint:
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
unnormalize_img = self.unnormalize_img(img)
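
With old and new lines interleaved in the hunk above, the reorganized control flow is easier to read in isolation. A toy restatement of the new loop, with stand-in lambdas in place of the real scheduler and unet calls (everything here is illustrative; only the loop structure mirrors the diff):

```python
import torch

num_timesteps, resample_times, b = 4, 3, 2
is_inpaint = True
img = torch.randn(b, 3, 8, 8)
inpaint_image = torch.randn_like(img)
inpaint_mask = torch.zeros(b, 1, 8, 8, dtype = torch.bool)

q_sample = lambda x, t: x + torch.randn_like(x)                      # stand-in for noise_scheduler.q_sample
p_sample = lambda x, t: 0.99 * x                                     # stand-in for one reverse-diffusion step
q_sample_from_to = lambda x, t_from, t_to: x + torch.randn_like(x)   # stand-in for the new helper

for time in reversed(range(num_timesteps)):
    is_last_timestep = time == 0
    for r in reversed(range(resample_times)):                        # resample_times is 1 when not inpainting
        is_last_resample_step = r == 0
        times = torch.full((b,), time, dtype = torch.long)
        if is_inpaint:
            # pin the known region to a freshly noised copy of the inpaint image
            img = img * ~inpaint_mask + q_sample(inpaint_image, times) * inpaint_mask
        img = p_sample(img, times)                                    # denoise one step
        if is_inpaint and not (is_last_timestep or is_last_resample_step):
            # RePaint: renoise from t - 1 back to t and run the step again
            img = q_sample_from_to(img, times - 1, times)
```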
@@ -2497,7 +2585,8 @@ class Decoder(nn.Module):
is_latent_diffusion = False,
lowres_noise_level = None,
inpaint_image = None,
inpaint_mask = None
inpaint_mask = None,
inpaint_resample_times = 5
):
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
@@ -2506,7 +2595,10 @@ class Decoder(nn.Module):
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
if exists(inpaint_image):
is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1
if is_inpaint:
inpaint_image = self.normalize_img(inpaint_image)
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
@@ -2519,39 +2611,49 @@ class Decoder(nn.Module):
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = alphas[time]
alpha_next = alphas[time_next]
is_last_timestep = time_next == 0
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
for r in reversed(range(0, resample_times)):
is_last_resample_step = r == 0
if exists(inpaint_image):
# following the repaint paper
# https://arxiv.org/abs/2201.09865
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
alpha = alphas[time]
alpha_next = alphas[time_next]
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
if is_inpaint:
# following the repaint paper
# https://arxiv.org/abs/2201.09865
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
if predict_x_start:
x_start = pred
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
pred_noise = pred
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
if clip_denoised:
x_start = self.dynamic_threshold(x_start)
if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if time_next > 0 else 0.
if predict_x_start:
x_start = pred
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
pred_noise = pred
img = x_start * alpha_next.sqrt() + \
c1 * noise + \
c2 * pred_noise
if clip_denoised:
x_start = self.dynamic_threshold(x_start)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if not is_last_timestep else 0.
img = x_start * alpha_next.sqrt() + \
c1 * noise + \
c2 * pred_noise
if is_inpaint and not (is_last_timestep or is_last_resample_step):
# in repaint, you renoise and resample up to 10 times every step
time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)
img = noise_scheduler.q_sample_from_to(img, time_next_cond, time_cond)
if exists(inpaint_image):
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
@@ -2658,7 +2760,8 @@ class Decoder(nn.Module):
stop_at_unet_number = None,
distributed = False,
inpaint_image = None,
inpaint_mask = None
inpaint_mask = None,
inpaint_resample_times = 5
):
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
@@ -2730,7 +2833,8 @@ class Decoder(nn.Module):
noise_scheduler = noise_scheduler,
timesteps = sample_timesteps,
inpaint_image = inpaint_image,
inpaint_mask = inpaint_mask
inpaint_mask = inpaint_mask,
inpaint_resample_times = inpaint_resample_times
)
img = vae.decode(img)
@@ -2845,7 +2949,7 @@ class DALLE2(nn.Module):
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)
if return_pil_images:
images = list(map(self.to_pil, images.unbind(dim = 0)))

View File

@@ -528,8 +528,12 @@ class Tracker:
elif save_type == 'model':
if isinstance(trainer, DiffusionPriorTrainer):
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
state_dict = trainer.accelerator.unwrap_model(prior).state_dict()
torch.save(state_dict, file_path)
prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)
# Remove CLIP if it is part of the model
original_clip = prior.clip
prior.clip = None
model_state_dict = prior.state_dict()
prior.clip = original_clip
elif isinstance(trainer, DecoderTrainer):
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
# Remove CLIP if it is part of the model
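
The save path now strips CLIP before serializing, presumably because it is large and loaded separately. The pattern in isolation, on a toy module (all names here are stand-ins, not the package's classes):

```python
import torch
from torch import nn

class ToyPrior(nn.Module):
    def __init__(self):
        super().__init__()
        self.clip = nn.Linear(4, 4)   # pretend this is the big frozen CLIP
        self.net = nn.Linear(4, 4)

prior = ToyPrior()
original_clip = prior.clip
prior.clip = None                     # detach so its weights are not written to the checkpoint
state_dict = prior.state_dict()       # only 'net.*' keys remain
prior.clip = original_clip            # reattach for continued use
torch.save(state_dict, '/tmp/prior_sans_clip.pt')
```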

View File

@@ -1,7 +1,7 @@
import json
from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
from x_clip import CLIP as XCLIP
from coca_pytorch import CoCa
@@ -25,11 +25,9 @@ def exists(val):
def default(val, d):
return val if exists(val) else d
def ListOrTuple(inner_type):
return Union[List[inner_type], Tuple[inner_type]]
def SingularOrIterable(inner_type):
return Union[inner_type, ListOrTuple(inner_type)]
InnerType = TypeVar('InnerType')
ListOrTuple = Union[List[InnerType], Tuple[InnerType]]
SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]
# general pydantic classes
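
As the #215 commit message notes, the point of switching from call syntax to subscripting is that a true generic alias can be resolved statically by Pylance/mypy. A self-contained sketch (ExampleConfig is hypothetical, not a class from the package):

```python
from typing import List, Tuple, TypeVar, Union
from pydantic import BaseModel

InnerType = TypeVar('InnerType')
ListOrTuple = Union[List[InnerType], Tuple[InnerType]]
SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]

class ExampleConfig(BaseModel):
    dim_mults: ListOrTuple[int] = [1, 2, 4, 8]   # checkers now see Union[List[int], Tuple[int]]
    lr: SingularOrIterable[float] = 1e-4         # the old ListOrTuple(int) call form was opaque to static analysis

print(ExampleConfig())
```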
@@ -222,13 +220,13 @@ class TrainDiffusionPriorConfig(BaseModel):
class UnetConfig(BaseModel):
dim: int
dim_mults: ListOrTuple(int)
dim_mults: ListOrTuple[int]
image_embed_dim: int = None
text_embed_dim: int = None
cond_on_text_encodings: bool = None
cond_dim: int = None
channels: int = 3
self_attn: ListOrTuple(int)
self_attn: ListOrTuple[int]
attn_dim_head: int = 32
attn_heads: int = 16
init_cross_embed: bool = True
@@ -237,16 +235,16 @@ class UnetConfig(BaseModel):
extra = "allow"
class DecoderConfig(BaseModel):
unets: ListOrTuple(UnetConfig)
unets: ListOrTuple[UnetConfig]
image_size: int = None
image_sizes: ListOrTuple(int) = None
image_sizes: ListOrTuple[int] = None
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable(int)] = None
sample_timesteps: Optional[SingularOrIterable[int]] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True
beta_schedule: ListOrTuple[str] = None # None means all cosine
learned_variance: SingularOrIterable[bool] = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
@@ -305,11 +303,11 @@ class DecoderDataConfig(BaseModel):
class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: SingularOrIterable(float) = 1e-4
wd: SingularOrIterable(float) = 0.01
warmup_steps: Optional[SingularOrIterable(int)] = None
lr: SingularOrIterable[float] = 1e-4
wd: SingularOrIterable[float] = 0.01
warmup_steps: Optional[SingularOrIterable[int]] = None
find_unused_parameters: bool = True
max_grad_norm: SingularOrIterable(float) = 0.5
max_grad_norm: SingularOrIterable[float] = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
cond_scale: Union[float, List[float]] = 1.0
@@ -320,7 +318,7 @@ class DecoderTrainConfig(BaseModel):
use_ema: bool = True
ema_beta: float = 0.999
amp: bool = False
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000

View File

@@ -174,7 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
diffusion_prior,
accelerator,
accelerator = None,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
@@ -186,8 +186,12 @@ class DiffusionPriorTrainer(nn.Module):
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
assert isinstance(accelerator, Accelerator)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
if not exists(accelerator):
accelerator = Accelerator(**accelerator_kwargs)
# assign some helpful member vars
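
With the relaxed signature, callers no longer have to build an Accelerator themselves; when none is passed, the trainer instantiates one internally from any accelerator_-prefixed kwargs. A usage sketch (assumes diffusion_prior is an already constructed DiffusionPrior; passing an explicit accelerator still works as before):

```python
trainer = DiffusionPriorTrainer(
    diffusion_prior,
    use_ema = True,
    lr = 3e-4,
    wd = 1e-2
    # no accelerator passed - the trainer builds Accelerator(**accelerator_kwargs) itself
)
```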
@@ -300,7 +304,7 @@ class DiffusionPriorTrainer(nn.Module):
# all processes need to load checkpoint. no restriction here
if isinstance(path_or_state, str):
path = Path(path)
path = Path(path_or_state)
assert path.exists()
loaded_obj = torch.load(str(path), map_location=self.device)

View File

@@ -1 +1 @@
__version__ = '1.0.1'
__version__ = '1.4.0'