Compare commits


12 Commits
1.0.1 ... 1.4.0

Author SHA1 Message Date
Phil Wang
748c7fe7af allow for cosine sim cross attention, modify linear attention in attempt to resolve issue on fp16 2022-07-29 11:12:18 -07:00
Phil Wang
80046334ad make sure entire readme runs without errors 2022-07-28 10:17:43 -07:00
Phil Wang
36fb46a95e fix readme and a small bug in DALLE2 class 2022-07-28 08:33:51 -07:00
Phil Wang
07abfcf45b rescale values in linear attention to mitigate overflows in fp16 setting 2022-07-27 12:27:38 -07:00
Phil Wang
2e35a9967d product management 2022-07-26 11:10:16 -07:00
Phil Wang
406e75043f add upsample combiner feature for the unets 2022-07-26 10:46:04 -07:00
Phil Wang
9646dfc0e6 fix path_or_state bug 2022-07-26 09:47:54 -07:00
Phil Wang
62043acb2f fix repaint 2022-07-24 15:29:06 -07:00
Phil Wang
417ff808e6 1.0.3 2022-07-22 13:16:57 -07:00
Aidan Dempster
f3d7e226ba Changed types to be generic instead of functions (#215)
This allows pylance to do proper type hinting and makes developing extensions to the package much easier
2022-07-22 13:16:29 -07:00
Phil Wang
48a1302428 1.0.2 2022-07-20 23:01:51 -07:00
Aidan Dempster
ccaa46b81b Re-introduced change that was accidentally rolled back (#212) 2022-07-20 23:01:19 -07:00
6 changed files with 220 additions and 110 deletions

View File

@@ -371,6 +371,7 @@ loss.backward()
 unet1 = Unet(
     dim = 128,
     image_embed_dim = 512,
+    text_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
     dim_mults=(1, 2, 4, 8),
@@ -395,7 +396,7 @@ decoder = Decoder(
 ).cuda()
 
 for unet_number in (1, 2):
-    loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
+    loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
     loss.backward()
 
 # do above for many steps
@@ -860,25 +861,23 @@ unet1 = Unet(
     text_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
-    dim_mults=(1, 2, 4, 8)
+    dim_mults=(1, 2, 4, 8),
+    cond_on_text_encodings = True,
 ).cuda()
 
 unet2 = Unet(
     dim = 16,
     image_embed_dim = 512,
-    text_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
     dim_mults = (1, 2, 4, 8, 16),
-    cond_on_text_encodings = True
 ).cuda()
 
 decoder = Decoder(
     unet = (unet1, unet2),
     image_sizes = (128, 256),
     clip = clip,
-    timesteps = 1000,
-    condition_on_text_encodings = True
+    timesteps = 1000
 ).cuda()
 
 decoder_trainer = DecoderTrainer(
@@ -903,8 +902,8 @@ for unet_number in (1, 2):
 # after much training
 # you can sample from the exponentially moving averaged unets as so
 
-mock_image_embed = torch.randn(4, 512).cuda()
-images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
+mock_image_embed = torch.randn(32, 512).cuda()
+images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)
 ```
 
 ### Diffusion Prior Training
@@ -1112,7 +1111,8 @@ For detailed information on training the diffusion prior, please refer to the [d
 - [x] allow for unet to be able to condition non-cross attention style as well
 - [x] speed up inference, read up on papers (ddim)
 - [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
-- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
+- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
+- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
 
 ## Citations

View File

@@ -516,6 +516,17 @@ class NoiseScheduler(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )
 
+    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
+        shape = x_from.shape
+        noise = default(noise, lambda: torch.randn_like(x_from))
+
+        alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)
+        sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)
+        alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)
+        sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)
+
+        return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
+
     def predict_start_from_noise(self, x_t, t, noise):
         return (
             extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
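The new `q_sample_from_to` helper renoises a partially denoised sample from a less-noisy timestep `from_t` up to a noisier timestep `to_t`, reusing the schedule's cumulative alphas; the repaint-style resampling loops further down call it as `noise_scheduler.q_sample_from_to(img, times - 1, times)`. A minimal standalone sketch of the same arithmetic, assuming a toy linear beta schedule and plain tensor indexing instead of the repo's `extract`/`default` helpers:

```python
import torch

# hypothetical standalone version of the same formula, for illustration only
def q_sample_from_to(x_from, from_t, to_t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise = None):
    noise = torch.randn_like(x_from) if noise is None else noise

    alpha      = sqrt_alphas_cumprod[from_t]
    sigma      = sqrt_one_minus_alphas_cumprod[from_t]
    alpha_next = sqrt_alphas_cumprod[to_t]
    sigma_next = sqrt_one_minus_alphas_cumprod[to_t]

    # rescale the existing sample toward the noisier marginal, then top up with fresh noise
    return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha

# toy linear beta schedule (an assumption, not necessarily the repo's default)
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)
sqrt_ac, sqrt_omac = alphas_cumprod.sqrt(), (1. - alphas_cumprod).sqrt()

x = torch.randn(4, 3, 64, 64)                                 # stand-in for a partially denoised image at t = 49
renoised = q_sample_from_to(x, 49, 50, sqrt_ac, sqrt_omac)    # pushed back up to t = 50
```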
@@ -1346,7 +1357,8 @@ class ResnetBlock(nn.Module):
         *,
         cond_dim = None,
         time_cond_dim = None,
-        groups = 8
+        groups = 8,
+        cosine_sim_cross_attn = False
     ):
         super().__init__()
 
@@ -1366,7 +1378,8 @@ class ResnetBlock(nn.Module):
             'b (h w) c',
             CrossAttention(
                 dim = dim_out,
-                context_dim = cond_dim
+                context_dim = cond_dim,
+                cosine_sim = cosine_sim_cross_attn
             )
         )
 
@@ -1401,11 +1414,12 @@ class CrossAttention(nn.Module):
         heads = 8,
         dropout = 0.,
         norm_context = False,
-        pb_relax_alpha = 32 ** 2
+        cosine_sim = False,
+        cosine_sim_scale = 16
     ):
         super().__init__()
-        self.pb_relax_alpha = pb_relax_alpha
-        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
+        self.cosine_sim = cosine_sim
+        self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
         self.heads = heads
         inner_dim = dim_head * heads
 
@@ -1441,7 +1455,10 @@ class CrossAttention(nn.Module):
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
-        q = q * self.scale
+        if self.cosine_sim:
+            q, k = map(l2norm, (q, k))
+
+        q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
 
         sim = einsum('b h i d, b h j d -> b h i j', q, k)
         max_neg_value = -torch.finfo(sim.dtype).max
@@ -1451,9 +1468,6 @@ class CrossAttention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)
 
-        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
-        sim = sim * self.pb_relax_alpha
-
         attn = sim.softmax(dim = -1)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
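The cross attention change swaps the PB-relax logit rescaling for an optional cosine-similarity attention: with `cosine_sim = True`, queries and keys are l2-normalized and a fixed `cosine_sim_scale` is split across q and k, so the attention logits stay bounded regardless of feature magnitude, which is the fp16-stability motivation named in commit 748c7fe7af. A minimal sketch of that path, assuming an `l2norm` that normalizes the last dimension (the repo has its own helper):

```python
import math
import torch
from torch import einsum
import torch.nn.functional as F

def l2norm(t):
    # assumption: unit-normalize the feature dimension, like the repo's helper
    return F.normalize(t, dim = -1)

def cosine_sim_attention(q, k, v, scale = 16):
    # q, k, v: (batch, heads, seq, dim_head)
    q, k = map(l2norm, (q, k))
    q, k = map(lambda t: t * math.sqrt(scale), (q, k))   # split the fixed scale across q and k

    sim = einsum('b h i d, b h j d -> b h i j', q, k)     # logits now bounded by +/- scale
    attn = sim.softmax(dim = -1)
    return einsum('b h i j, b h j d -> b h i d', attn, v)

q, k, v = (torch.randn(2, 8, 128, 64) for _ in range(3))
out = cosine_sim_attention(q, k, v)   # (2, 8, 128, 64)
```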
@@ -1483,6 +1497,7 @@ class LinearAttention(nn.Module):
     def forward(self, fmap):
         h, x, y = self.heads, *fmap.shape[-2:]
+        seq_len = x * y
 
         fmap = self.norm(fmap)
 
         q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
@@ -1492,6 +1507,9 @@ class LinearAttention(nn.Module):
         k = k.softmax(dim = -2)
         q = q * self.scale
 
+        v = l2norm(v)
+        k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))
+
         context = einsum('b n d, b n e -> b d e', k, v)
 
         out = einsum('b n d, b d e -> b n e', q, context)
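In linear attention the fix from commit 07abfcf45b is a rescaling: values are l2-normalized and both keys and values are divided by `sqrt(seq_len)` before the key-value context matrix is summed over the sequence, keeping the accumulated sums within fp16 range. A simplified sketch covering only the lines touched above (no output projection or nonlinearity on q):

```python
import math
import torch
from torch import einsum
import torch.nn.functional as F

def linear_attention_core(q, k, v, scale):
    # q, k, v: (batch * heads, seq_len, dim_head); simplified core, mirroring only the lines changed in the diff
    seq_len = q.shape[-2]

    k = k.softmax(dim = -2)
    q = q * scale

    v = F.normalize(v, dim = -1)                           # l2norm(v): bound the value magnitudes
    k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))   # rescale so the pooled context stays small in fp16

    context = einsum('b n d, b n e -> b d e', k, v)        # per batch-head (dim_head, dim_head) summary
    return einsum('b n d, b d e -> b n e', q, context)

q, k, v = (torch.randn(16, 4096, 32) for _ in range(3))
out = linear_attention_core(q, k, v, scale = 32 ** -0.5)   # (16, 4096, 32)
```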
@@ -1527,6 +1545,38 @@ class CrossEmbedLayer(nn.Module):
         fmaps = tuple(map(lambda conv: conv(x), self.convs))
         return torch.cat(fmaps, dim = 1)
 
+class UpsampleCombiner(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        enabled = False,
+        dim_ins = tuple(),
+        dim_outs = tuple()
+    ):
+        super().__init__()
+        assert len(dim_ins) == len(dim_outs)
+        self.enabled = enabled
+
+        if not self.enabled:
+            self.dim_out = dim
+            return
+
+        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
+        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
+
+    def forward(self, x, fmaps = None):
+        target_size = x.shape[-1]
+
+        fmaps = default(fmaps, tuple())
+
+        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
+            return x
+
+        fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
+        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
+        return torch.cat((x, *outs), dim = 1)
+
 class Unet(nn.Module):
     def __init__(
         self,
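`UpsampleCombiner` is the upsample-feature-map combination from the unet squared idea noted in the README todo: it records the output of every upsample stage, resizes each to the final feature-map resolution, projects each through a `Block`, and concatenates the results with the top-level feature map, which is why `final_resnet_block` later takes `self.upsample_combiner.dim_out + dim` input channels. A rough sketch of the forward logic with made-up dimensions and a stand-in for the repo's `Block`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    # stand-in for the repo's Block (conv + norm + activation); only the channel mapping matters here
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Conv2d(dim_in, dim_out, 3, padding = 1)

    def forward(self, x):
        return self.proj(x)

dim = 64
upsample_dims = (512, 256, 128)                    # hypothetical channel counts of the upsample stages
convs = nn.ModuleList([Block(d, dim) for d in upsample_dims])

x = torch.randn(1, dim, 64, 64)                    # top-level feature map
fmaps = [torch.randn(1, d, 64 // (2 ** i), 64 // (2 ** i)) for i, d in enumerate(upsample_dims)]

fmaps = [F.interpolate(f, x.shape[-1]) for f in fmaps]   # bring everything to the final resolution
outs = [conv(f) for f, conv in zip(fmaps, convs)]
combined = torch.cat((x, *outs), dim = 1)          # (1, dim + dim * len(upsample_dims), 64, 64)
```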
@@ -1547,6 +1597,7 @@ class Unet(nn.Module):
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
         sparse_attn = False,
+        cosine_sim_cross_attn = False,
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
         cond_on_text_encodings = False,
         max_text_len = 256,
@@ -1564,6 +1615,7 @@ class Unet(nn.Module):
         scale_skip_connection = False,
         pixel_shuffle_upsample = True,
         final_conv_kernel_size = 1,
+        combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
         **kwargs
     ):
         super().__init__()
@@ -1689,9 +1741,13 @@ class Unet(nn.Module):
         upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
 
+        # prepare resnet klass
+
+        resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
+
         # give memory efficient unet an initial resnet block
 
-        self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
+        self.init_resnet_block = resnet_block(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
 
         # layers
@@ -1699,7 +1755,8 @@ class Unet(nn.Module):
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
         skip_connect_dims = [] # keeping track of skip connection dimensions
+        upsample_combiner_dims = [] # keeping track of dimensions for final upsample feature map combiner
 
         for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
             is_first = ind == 0
@@ -1717,17 +1774,17 @@ class Unet(nn.Module):
             self.downs.append(nn.ModuleList([
                 downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
-                ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
-                nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+                resnet_block(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([resnet_block(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 attention,
                 downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
             ]))
 
         mid_dim = dims[-1]
 
-        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+        self.mid_block1 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
         self.mid_attn = create_self_attn(mid_dim)
-        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+        self.mid_block2 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
 
         for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
             is_last = ind >= (len(in_out) - 1)
@@ -1741,14 +1798,27 @@ class Unet(nn.Module):
             elif sparse_attn:
                 attention = Residual(LinearAttention(dim_out, **attn_kwargs))
 
+            upsample_combiner_dims.append(dim_out)
+
             self.ups.append(nn.ModuleList([
-                ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
-                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+                resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 attention,
                 upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
             ]))
 
-        self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
+        # whether to combine outputs from all upsample blocks for final resnet block
+
+        self.upsample_combiner = UpsampleCombiner(
+            dim = dim,
+            enabled = combine_upsample_fmaps,
+            dim_ins = upsample_combiner_dims,
+            dim_outs = (dim,) * len(upsample_combiner_dims)
+        )
+
+        # a final resnet block
+
+        self.final_resnet_block = resnet_block(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
 
         out_dim_in = dim + (channels if lowres_cond else 0)
@@ -1772,7 +1842,7 @@ class Unet(nn.Module):
             channels == self.channels and \
             cond_on_image_embeds == self.cond_on_image_embeds and \
             cond_on_text_encodings == self.cond_on_text_encodings and \
-            cond_on_lowres_noise == self.cond_on_lowres_noise and \
+            lowres_noise_cond == self.lowres_noise_cond and \
             channels_out == self.channels_out:
             return self
@@ -1942,7 +2012,8 @@ class Unet(nn.Module):
         # go through the layers of the unet, down and up
 
-        hiddens = []
+        down_hiddens = []
+        up_hiddens = []
 
         for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
             if exists(pre_downsample):
@@ -1952,10 +2023,10 @@ class Unet(nn.Module):
             for resnet_block in resnet_blocks:
                 x = resnet_block(x, t, c)
-                hiddens.append(x)
+                down_hiddens.append(x.contiguous())
 
             x = attn(x)
-            hiddens.append(x.contiguous())
+            down_hiddens.append(x.contiguous())
 
             if exists(post_downsample):
                 x = post_downsample(x)
@@ -1967,7 +2038,7 @@ class Unet(nn.Module):
         x = self.mid_block2(x, t, mid_c)
 
-        connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
+        connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
 
         for init_block, resnet_blocks, attn, upsample in self.ups:
             x = connect_skip(x)
@@ -1978,8 +2049,12 @@ class Unet(nn.Module):
                 x = resnet_block(x, t, c)
 
             x = attn(x)
+            up_hiddens.append(x.contiguous())
+
             x = upsample(x)
 
+        x = self.upsample_combiner(x, up_hiddens)
+
         x = torch.cat((x, r), dim = 1)
 
         x = self.final_resnet_block(x, t)
@@ -2432,14 +2507,18 @@ class Decoder(nn.Module):
         is_latent_diffusion = False,
         lowres_noise_level = None,
         inpaint_image = None,
-        inpaint_mask = None
+        inpaint_mask = None,
+        inpaint_resample_times = 5
     ):
         device = self.device
         b = shape[0]
         img = torch.randn(shape, device = device)
 
-        if exists(inpaint_image):
+        is_inpaint = exists(inpaint_image)
+        resample_times = inpaint_resample_times if is_inpaint else 1
+
+        if is_inpaint:
             inpaint_image = self.normalize_img(inpaint_image)
             inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
             inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
@@ -2449,31 +2528,40 @@ class Decoder(nn.Module):
         if not is_latent_diffusion:
             lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
 
-        for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
-            times = torch.full((b,), i, device = device, dtype = torch.long)
-
-            if exists(inpaint_image):
-                # following the repaint paper
-                # https://arxiv.org/abs/2201.09865
-                noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
-                img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
-
-            img = self.p_sample(
-                unet,
-                img,
-                times,
-                image_embed = image_embed,
-                text_encodings = text_encodings,
-                cond_scale = cond_scale,
-                lowres_cond_img = lowres_cond_img,
-                lowres_noise_level = lowres_noise_level,
-                predict_x_start = predict_x_start,
-                noise_scheduler = noise_scheduler,
-                learned_variance = learned_variance,
-                clip_denoised = clip_denoised
-            )
-
-        if exists(inpaint_image):
+        for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
+            is_last_timestep = time == 0
+
+            for r in reversed(range(0, resample_times)):
+                is_last_resample_step = r == 0
+
+                times = torch.full((b,), time, device = device, dtype = torch.long)
+
+                if is_inpaint:
+                    # following the repaint paper
+                    # https://arxiv.org/abs/2201.09865
+                    noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
+                    img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
+
+                img = self.p_sample(
+                    unet,
+                    img,
+                    times,
+                    image_embed = image_embed,
+                    text_encodings = text_encodings,
+                    cond_scale = cond_scale,
+                    lowres_cond_img = lowres_cond_img,
+                    lowres_noise_level = lowres_noise_level,
+                    predict_x_start = predict_x_start,
+                    noise_scheduler = noise_scheduler,
+                    learned_variance = learned_variance,
+                    clip_denoised = clip_denoised
+                )
+
+                if is_inpaint and not (is_last_timestep or is_last_resample_step):
+                    # in repaint, you renoise and resample up to 10 times every step
+                    img = noise_scheduler.q_sample_from_to(img, times - 1, times)
+
+        if is_inpaint:
             img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
 
         unnormalize_img = self.unnormalize_img(img)
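Both sampling loops now follow the repaint resampling schedule (https://arxiv.org/abs/2201.09865): at each timestep the known region is re-noised to that timestep and pasted in, a reverse step is taken, and on every inner iteration except the last the result is pushed one step back up the noise schedule with `q_sample_from_to` and denoised again, for `inpaint_resample_times` (default 5) rounds. A stripped-down sketch of just the control flow, with `noise_to`, `denoise_step` and `renoise` as hypothetical stand-ins for `q_sample`, `p_sample` and `q_sample_from_to`:

```python
import torch

def repaint_sample_loop(img, inpaint_image, inpaint_mask, num_timesteps, resample_times, noise_to, denoise_step, renoise):
    # inpaint_mask: True where the image is known and should be kept
    for time in reversed(range(num_timesteps)):
        is_last_timestep = time == 0

        for r in reversed(range(resample_times)):
            is_last_resample_step = r == 0

            # paste the known region in at the current noise level
            img = img * ~inpaint_mask + noise_to(inpaint_image, time) * inpaint_mask

            # one reverse-diffusion step: time -> time - 1
            img = denoise_step(img, time)

            if not (is_last_timestep or is_last_resample_step):
                # renoise back up to `time` and try again, letting both regions harmonize
                img = renoise(img, time - 1, time)

    return img * ~inpaint_mask + inpaint_image * inpaint_mask

img = torch.randn(1, 3, 64, 64)
known = torch.zeros_like(img)
mask = torch.zeros(1, 1, 64, 64, dtype = torch.bool)

out = repaint_sample_loop(
    img, known, mask,
    num_timesteps = 10, resample_times = 5,
    noise_to = lambda x, t: x + torch.randn_like(x) * (t / 10),  # toy forward noising
    denoise_step = lambda x, t: x * 0.9,                         # toy reverse step
    renoise = lambda x, s, t: x + torch.randn_like(x) * 0.1      # toy one-step renoising
)
```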
@@ -2497,7 +2585,8 @@ class Decoder(nn.Module):
         is_latent_diffusion = False,
         lowres_noise_level = None,
         inpaint_image = None,
-        inpaint_mask = None
+        inpaint_mask = None,
+        inpaint_resample_times = 5
     ):
         batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
@@ -2506,7 +2595,10 @@ class Decoder(nn.Module):
         times = list(reversed(times.int().tolist()))
         time_pairs = list(zip(times[:-1], times[1:]))
 
-        if exists(inpaint_image):
+        is_inpaint = exists(inpaint_image)
+        resample_times = inpaint_resample_times if is_inpaint else 1
+
+        if is_inpaint:
             inpaint_image = self.normalize_img(inpaint_image)
             inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
             inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
@@ -2519,39 +2611,49 @@
             lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
 
         for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
-            alpha = alphas[time]
-            alpha_next = alphas[time_next]
-
-            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
-
-            if exists(inpaint_image):
-                # following the repaint paper
-                # https://arxiv.org/abs/2201.09865
-                noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
-                img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
-
-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
-
-            if learned_variance:
-                pred, _ = pred.chunk(2, dim = 1)
-
-            if predict_x_start:
-                x_start = pred
-                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
-            else:
-                x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
-                pred_noise = pred
-
-            if clip_denoised:
-                x_start = self.dynamic_threshold(x_start)
-
-            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
-            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
-            noise = torch.randn_like(img) if time_next > 0 else 0.
-
-            img = x_start * alpha_next.sqrt() + \
-                c1 * noise + \
-                c2 * pred_noise
+            is_last_timestep = time_next == 0
+
+            for r in reversed(range(0, resample_times)):
+                is_last_resample_step = r == 0
+
+                alpha = alphas[time]
+                alpha_next = alphas[time_next]
+
+                time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+                if is_inpaint:
+                    # following the repaint paper
+                    # https://arxiv.org/abs/2201.09865
+                    noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
+                    img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
+
+                pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+
+                if learned_variance:
+                    pred, _ = pred.chunk(2, dim = 1)
+
+                if predict_x_start:
+                    x_start = pred
+                    pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
+                else:
+                    x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
+                    pred_noise = pred
+
+                if clip_denoised:
+                    x_start = self.dynamic_threshold(x_start)
+
+                c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+                c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+                noise = torch.randn_like(img) if not is_last_timestep else 0.
+
+                img = x_start * alpha_next.sqrt() + \
+                    c1 * noise + \
+                    c2 * pred_noise
+
+                if is_inpaint and not (is_last_timestep or is_last_resample_step):
+                    # in repaint, you renoise and resample up to 10 times every step
+                    time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)
+                    img = noise_scheduler.q_sample_from_to(img, time_next_cond, time_cond)
 
         if exists(inpaint_image):
             img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
@@ -2658,7 +2760,8 @@ class Decoder(nn.Module):
         stop_at_unet_number = None,
         distributed = False,
         inpaint_image = None,
-        inpaint_mask = None
+        inpaint_mask = None,
+        inpaint_resample_times = 5
     ):
         assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
@@ -2730,7 +2833,8 @@ class Decoder(nn.Module):
                     noise_scheduler = noise_scheduler,
                     timesteps = sample_timesteps,
                     inpaint_image = inpaint_image,
-                    inpaint_mask = inpaint_mask
+                    inpaint_mask = inpaint_mask,
+                    inpaint_resample_times = inpaint_resample_times
                 )
 
                 img = vae.decode(img)
@@ -2845,7 +2949,7 @@ class DALLE2(nn.Module):
         image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
 
         text_cond = text if self.decoder_need_text_cond else None
-        images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
+        images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)
 
         if return_pil_images:
            images = list(map(self.to_pil, images.unbind(dim = 0)))

View File

@@ -528,8 +528,12 @@ class Tracker:
         elif save_type == 'model':
             if isinstance(trainer, DiffusionPriorTrainer):
                 prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
-                state_dict = trainer.accelerator.unwrap_model(prior).state_dict()
-                torch.save(state_dict, file_path)
+                prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)
+                # Remove CLIP if it is part of the model
+                original_clip = prior.clip
+                prior.clip = None
+                model_state_dict = prior.state_dict()
+                prior.clip = original_clip
             elif isinstance(trainer, DecoderTrainer):
                 decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
                 # Remove CLIP if it is part of the model
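The checkpointing change makes the prior path match what the decoder path already did: the CLIP adapter is temporarily detached before `state_dict()` so its (typically frozen, separately distributed) weights are not baked into every checkpoint, then reattached. A generic sketch of the pattern, assuming only that the module exposes a `clip` attribute:

```python
import torch.nn as nn

def state_dict_without_clip(model: nn.Module) -> dict:
    # temporarily detach the CLIP submodule so its weights are not written into the checkpoint
    original_clip = model.clip
    model.clip = None
    try:
        state = model.state_dict()
    finally:
        model.clip = original_clip   # always reattach, even if state_dict() raises
    return state

class ToyPrior(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 8)
        self.clip = nn.Linear(512, 512)   # stand-in for a frozen CLIP adapter

prior = ToyPrior()
assert not any(k.startswith('clip') for k in state_dict_without_clip(prior))
```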

View File

@@ -1,7 +1,7 @@
 import json
 from torchvision import transforms as T
 from pydantic import BaseModel, validator, root_validator
-from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
+from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
 
 from x_clip import CLIP as XCLIP
 from coca_pytorch import CoCa
@@ -25,11 +25,9 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
-def ListOrTuple(inner_type):
-    return Union[List[inner_type], Tuple[inner_type]]
-
-def SingularOrIterable(inner_type):
-    return Union[inner_type, ListOrTuple(inner_type)]
+InnerType = TypeVar('InnerType')
+ListOrTuple = Union[List[InnerType], Tuple[InnerType]]
+SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]
 
 # general pydantic classes
@@ -222,13 +220,13 @@ class TrainDiffusionPriorConfig(BaseModel):
 class UnetConfig(BaseModel):
     dim: int
-    dim_mults: ListOrTuple(int)
+    dim_mults: ListOrTuple[int]
     image_embed_dim: int = None
     text_embed_dim: int = None
     cond_on_text_encodings: bool = None
     cond_dim: int = None
     channels: int = 3
-    self_attn: ListOrTuple(int)
+    self_attn: ListOrTuple[int]
     attn_dim_head: int = 32
     attn_heads: int = 16
     init_cross_embed: bool = True
@@ -237,16 +235,16 @@ class UnetConfig(BaseModel):
         extra = "allow"
 
 class DecoderConfig(BaseModel):
-    unets: ListOrTuple(UnetConfig)
+    unets: ListOrTuple[UnetConfig]
     image_size: int = None
-    image_sizes: ListOrTuple(int) = None
+    image_sizes: ListOrTuple[int] = None
     clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
-    sample_timesteps: Optional[SingularOrIterable(int)] = None
+    sample_timesteps: Optional[SingularOrIterable[int]] = None
     loss_type: str = 'l2'
-    beta_schedule: ListOrTuple(str) = 'cosine'
-    learned_variance: bool = True
+    beta_schedule: ListOrTuple[str] = None # None means all cosine
+    learned_variance: SingularOrIterable[bool] = True
     image_cond_drop_prob: float = 0.1
     text_cond_drop_prob: float = 0.5
@@ -305,11 +303,11 @@ class DecoderDataConfig(BaseModel):
 class DecoderTrainConfig(BaseModel):
     epochs: int = 20
-    lr: SingularOrIterable(float) = 1e-4
-    wd: SingularOrIterable(float) = 0.01
-    warmup_steps: Optional[SingularOrIterable(int)] = None
+    lr: SingularOrIterable[float] = 1e-4
+    wd: SingularOrIterable[float] = 0.01
+    warmup_steps: Optional[SingularOrIterable[int]] = None
     find_unused_parameters: bool = True
-    max_grad_norm: SingularOrIterable(float) = 0.5
+    max_grad_norm: SingularOrIterable[float] = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
     cond_scale: Union[float, List[float]] = 1.0
@@ -320,7 +318,7 @@ class DecoderTrainConfig(BaseModel):
     use_ema: bool = True
     ema_beta: float = 0.999
     amp: bool = False
-    unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
+    unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
 
 class DecoderEvaluateConfig(BaseModel):
     n_evaluation_samples: int = 1000
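The config change from #215 replaces the old alias-building functions with generic type aliases built on a `TypeVar`, so `ListOrTuple[int]` is an ordinary subscriptable alias that Pylance and other static checkers can expand, instead of a runtime function call. A small sketch of the same pattern outside of pydantic:

```python
from typing import List, Tuple, TypeVar, Union

InnerType = TypeVar('InnerType')

# generic aliases: subscriptable, so static type checkers can expand them
ListOrTuple = Union[List[InnerType], Tuple[InnerType]]
SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]

def normalize(value: SingularOrIterable[int]) -> List[int]:
    # accept either a single int or a list/tuple of ints
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]

assert normalize(3) == [3]
assert normalize((1, 2)) == [1, 2]
```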

View File

@@ -174,7 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
     def __init__(
         self,
         diffusion_prior,
-        accelerator,
+        accelerator = None,
         use_ema = True,
         lr = 3e-4,
         wd = 1e-2,
@@ -186,8 +186,12 @@ class DiffusionPriorTrainer(nn.Module):
     ):
         super().__init__()
         assert isinstance(diffusion_prior, DiffusionPrior)
-        assert isinstance(accelerator, Accelerator)
 
         ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
+        accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
+
+        if not exists(accelerator):
+            accelerator = Accelerator(**accelerator_kwargs)
 
         # assign some helpful member vars
@@ -300,7 +304,7 @@ class DiffusionPriorTrainer(nn.Module):
         # all processes need to load checkpoint. no restriction here
 
         if isinstance(path_or_state, str):
-            path = Path(path)
+            path = Path(path_or_state)
            assert path.exists()
 
            loaded_obj = torch.load(str(path), map_location=self.device)
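`DiffusionPriorTrainer` now builds its own `Accelerator` when none is supplied, configured from any `accelerator_`-prefixed kwargs split out of `**kwargs`; the same file also fixes the `path_or_state` load bug where `Path(path)` referenced an undefined name. A hedged sketch of the optional-accelerator pattern, with a simplified stand-in for the repo's `groupby_prefix_and_trim`:

```python
from accelerate import Accelerator

def split_prefixed_kwargs(prefix, kwargs):
    # simplified stand-in for groupby_prefix_and_trim: pull out `prefix*` keys and strip the prefix
    matched = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    rest = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return matched, rest

def make_trainer(diffusion_prior, accelerator = None, **kwargs):
    accelerator_kwargs, kwargs = split_prefixed_kwargs('accelerator_', kwargs)

    if accelerator is None:
        # build a default Accelerator from the prefixed kwargs, e.g. accelerator_mixed_precision = 'fp16'
        accelerator = Accelerator(**accelerator_kwargs)

    return diffusion_prior, accelerator, kwargs
```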

View File

@@ -1 +1 @@
-__version__ = '1.0.1'
+__version__ = '1.4.0'