@@ -12,10 +12,8 @@ from torch.utils.checkpoint import checkpoint
 from torch import nn, einsum
 import torchvision.transforms as T
 
-from einops import rearrange, repeat, reduce
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
-from einops_exts import rearrange_many, repeat_many, check_shape
-from einops_exts.torch import EinopsToAndFrom
 
 from kornia.filters import gaussian_blur2d
 import kornia.augmentation as K
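
The import swap above is the heart of this change: einops 0.6's `pack` / `unpack` subsume the removed `einops_exts` helpers. A minimal sketch of the round trip the rest of the diff relies on (the shapes are illustrative, not taken from this repo):

    import torch
    from einops import rearrange, pack, unpack

    x = torch.randn(2, 8, 16, 16)            # (b, c, h, w) feature map
    x = rearrange(x, 'b c ... -> b ... c')   # channels last: (2, 16, 16, 8)

    # pack flattens every axis covered by '*' and records the original
    # shapes in ps so they can be restored later
    x, ps = pack([x], 'b * c')               # (2, 256, 8)

    x, = unpack(x, ps, 'b * c')              # back to (2, 16, 16, 8)
    x = rearrange(x, 'b ... c -> b c ...')   # (2, 8, 16, 16)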
@@ -669,6 +667,23 @@ class NoiseScheduler(nn.Module):
             return loss
         return loss * extract(self.p2_loss_weight, times, loss.shape)
 
+# rearrange image to sequence
+
+class RearrangeToSequence(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x):
+        x = rearrange(x, 'b c ... -> b ... c')
+        x, ps = pack([x], 'b * c')
+
+        x = self.fn(x)
+
+        x, = unpack(x, ps, 'b * c')
+        x = rearrange(x, 'b ... c -> b c ...')
+        return x
+
 # diffusion prior
 
 class LayerNorm(nn.Module):
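
The newly added `RearrangeToSequence` flattens any number of trailing axes into a single sequence axis, applies `fn`, then restores the original layout. A quick shape check, using `nn.Identity()` as a hypothetical stand-in for the attention module it will actually wrap:

    import torch
    from torch import nn

    attn = RearrangeToSequence(nn.Identity())

    feats_2d = torch.randn(1, 32, 8, 8)      # image: b c h w
    feats_3d = torch.randn(1, 32, 4, 8, 8)   # video-like: b c f h w

    assert attn(feats_2d).shape == feats_2d.shape
    assert attn(feats_3d).shape == feats_3d.shape   # '...' handles extra axes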
@@ -867,7 +882,7 @@ class Attention(nn.Module):
         # add null key / value for classifier free guidance in prior net
 
-        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
+        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
 
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
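
This rewrite pattern recurs throughout the diff: `einops_exts.repeat_many` / `rearrange_many` simply mapped one einops call over several tensors, so each call site becomes plain `map` plus `repeat` / `rearrange`. A sketch of the equivalence, with made-up dimensions:

    import torch
    from einops import repeat

    null_kv = torch.randn(2, 64)   # two learned null vectors of dim 64
    b = 8

    # one repeat per tensor, exactly what repeat_many used to do
    nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), null_kv.unbind(dim = -2))
    assert nk.shape == (8, 1, 64) and nv.shape == (8, 1, 64)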
@@ -1124,7 +1139,7 @@ class DiffusionPriorNetwork(nn.Module):
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
 
         if self.self_cond:
-            learned_queries = torch.cat((image_embed, self_cond), dim = -2)
+            learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
 
         tokens = torch.cat((
             text_encodings,
@@ -1334,10 +1349,7 @@ class DiffusionPrior(nn.Module):
             # predict noise
 
-            if self.predict_x_start or self.predict_v:
-                pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
-            else:
-                pred_noise = pred
+            pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
 
             if time_next < 0:
                 image_embed = x_start
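
Dropping the branch is safe because `predict_noise_from_start` merely inverts the forward-process identity, so recomputing the noise from the (possibly clamped or renormalized) `x_start` keeps the pair consistent even when the network predicted noise directly. A sketch of that identity in standard DDPM notation, assuming precomputed `sqrt(abar_t)` buffers:

    # forward process:  x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps
    # solving for eps:  eps = (x_t - sqrt(abar_t) * x0) / sqrt(1 - abar_t)
    def predict_noise_from_start_sketch(x_t, x0, sqrt_abar_t, sqrt_one_minus_abar_t):
        return (x_t - sqrt_abar_t * x0) / sqrt_one_minus_abar_t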
@@ -1632,14 +1644,10 @@ class ResnetBlock(nn.Module):
         self.cross_attn = None
 
         if exists(cond_dim):
-            self.cross_attn = EinopsToAndFrom(
-                'b c h w',
-                'b (h w) c',
-                CrossAttention(
-                    dim = dim_out,
-                    context_dim = cond_dim,
-                    cosine_sim = cosine_sim_cross_attn
-                )
-            )
+            self.cross_attn = CrossAttention(
+                dim = dim_out,
+                context_dim = cond_dim,
+                cosine_sim = cosine_sim_cross_attn
+            )
 
         self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
@@ -1658,8 +1666,15 @@ class ResnetBlock(nn.Module):
 
         if exists(self.cross_attn):
             assert exists(cond)
+
+            h = rearrange(h, 'b c ... -> b ... c')
+            h, ps = pack([h], 'b * c')
+
             h = self.cross_attn(h, context = cond) + h
 
+            h, = unpack(h, ps, 'b * c')
+            h = rearrange(h, 'b ... c -> b c ...')
+
         h = self.block2(h)
         return h + self.res_conv(x)
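
The pack/unpack dance added above open-codes what the removed `EinopsToAndFrom('b c h w', 'b (h w) c', ...)` wrapper did, while also covering inputs with extra trailing axes. For a plain 4-d feature map the round trip is just:

    import torch
    from einops import rearrange

    h = torch.randn(2, 64, 16, 16)
    seq = rearrange(h, 'b c h w -> b (h w) c')     # what EinopsToAndFrom did
    out = rearrange(seq, 'b (h w) c -> b c h w', h = 16, w = 16)
    assert torch.equal(h, out)                     # exact round trip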
@@ -1705,11 +1720,11 @@ class CrossAttention(nn.Module):
         q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
 
-        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
 
         # add null key / value for classifier free guidance in prior net
 
-        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
+        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
 
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
@@ -1762,7 +1777,7 @@ class LinearAttention(nn.Module):
         fmap = self.norm(fmap)
         q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
-        q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
 
         q = q.softmax(dim = -1)
         k = k.softmax(dim = -2)
 
@@ -1996,7 +2011,7 @@ class Unet(nn.Module):
 
         self_attn = cast_tuple(self_attn, num_stages)
 
-        create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
+        create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim, **attn_kwargs)))
 
         # resnet block klass
 
@@ -2730,11 +2745,16 @@ class Decoder(nn.Module):
 
         if exists(unet_number):
             unet = self.get_unet(unet_number)
 
+        # devices
+
+        cuda, cpu = torch.device('cuda'), torch.device('cpu')
+
         self.cuda()
 
         devices = [module_device(unet) for unet in self.unets]
-        self.unets.cpu()
-        unet.cuda()
+
+        self.unets.to(cpu)
+        unet.to(cuda)
 
         yield
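
The hunk above switches from bare `.cpu()` / `.cuda()` calls to explicit `torch.device` objects, and the saved `devices` list suggests each unet's placement is restored after the `yield` (outside this hunk). A self-contained sketch of the overall pattern, with illustrative names rather than this repo's API:

    from contextlib import contextmanager
    import torch

    def module_device(module):
        return next(module.parameters()).device

    @contextmanager
    def one_module_in_gpu(modules, active):
        # remember where every module lives, then park them all on CPU
        devices = [module_device(m) for m in modules]
        for m in modules:
            m.to(torch.device('cpu'))
        active.to(torch.device('cuda'))   # assumes a CUDA device is available
        try:
            yield
        finally:
            # restore every module to its original device
            for m, d in zip(modules, devices):
                m.to(d)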
@@ -2975,10 +2995,7 @@ class Decoder(nn.Module):
             # predict noise
 
-            if predict_x_start or predict_v:
-                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
-            else:
-                pred_noise = pred
+            pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
 
             c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
             c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
 
@@ -3120,7 +3137,8 @@ class Decoder(nn.Module):
         distributed = False,
         inpaint_image = None,
         inpaint_mask = None,
-        inpaint_resample_times = 5
+        inpaint_resample_times = 5,
+        one_unet_in_gpu_at_time = True
     ):
         assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
 
@@ -3143,6 +3161,7 @@ class Decoder(nn.Module):
             assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
             prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
             img = resize_image_to(image, prev_unet_output_size, nearest = True)
 
+        is_cuda = next(self.parameters()).is_cuda
 
         num_unets = self.num_unets
@@ -3152,7 +3171,7 @@ class Decoder(nn.Module):
             if unet_number < start_at_unet_number:
                 continue # It's the easiest way to do it
 
-            context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
+            context = self.one_unet_in_gpu(unet = unet) if is_cuda and one_unet_in_gpu_at_time else null_context()
 
             with context:
                 # prepare low resolution conditioning for upsamplers
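
The new `one_unet_in_gpu_at_time` flag gates that context manager, letting callers with enough GPU memory keep every unet resident instead of swapping them in one at a time. A hypothetical call, assuming a built `decoder` and a precomputed `image_embed`:

    # keep all unets on the GPU rather than moving them in and out
    images = decoder.sample(image_embed = image_embed, one_unet_in_gpu_at_time = False)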
@@ -3229,7 +3248,7 @@ class Decoder(nn.Module):
         learned_variance = self.learned_variance[unet_index]
         b, c, h, w, device, = *image.shape, image.device
 
-        check_shape(image, 'b c h w', c = self.channels)
+        assert image.shape[1] == self.channels
         assert h >= target_image_size and w >= target_image_size
 
         times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)