add diffusion code for the image embedding. nearly all the code is there except for the cascading ddpm in the decoder (with upscaling etc)

This commit is contained in:
Phil Wang
2022-04-13 10:06:45 -07:00
parent 6d4e9c97bf
commit 791d27326a
2 changed files with 158 additions and 10 deletions

View File

@@ -221,8 +221,8 @@ class DiffusionPriorNetwork(nn.Module):
def forward(
self,
image_embed,
*,
diffusion_timesteps,
*,
text_encodings,
text_embed,
mask = None,
@@ -272,21 +272,169 @@ class DiffusionPriorNetwork(nn.Module):
class DiffusionPrior(nn.Module):
def __init__(
self,
net,
*,
clip
clip,
timesteps = 1000,
cond_prob_drop = 0.2,
loss_type = 'l1'
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
def forward(
self,
*,
text,
image = None
):
self.net = net
self.image_embed_dim = clip.dim_latent
self.channels = clip.image_channels
self.image_size = clip.image_size
self.cond_prob_drop = cond_prob_drop
betas = cosine_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.loss_type = loss_type
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return image_embed
def get_text_cond(self, text):
text_encodings = self.clip.text_transformer(text)
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, image_embed, text_cond = None, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, text_cond):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
@torch.no_grad()
def sample(self, text):
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_cond = self.get_text_cond(text)
return self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, image_embed, t, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)
x_recon = self.net(
image_embed_noisy,
t,
cond_prob_drop = self.cond_prob_drop,
**text_cond
)
if self.loss_type == 'l1':
loss = F.l1_loss(noise, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(noise, x_recon)
else:
raise NotImplementedError()
return loss
def forward(self, text, image, *args, **kwargs):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
text_cond = self.get_text_cond(text)
loss = self.p_losses(x, times, image_embed = image_embed, text_cond = text_cond, *args, **kwargs)
return loss
# decoder
def Upsample(dim):
@@ -428,9 +576,9 @@ class Unet(nn.Module):
def forward(
self,
x,
time,
*,
image_embed,
time,
text_encodings = None,
cond_prob_drop = 0.
):