Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)
Add diffusion code for the image embedding. Nearly all the code is there, except for the cascading DDPM in the decoder (with upscaling etc.).
```diff
@@ -221,8 +221,8 @@ class DiffusionPriorNetwork(nn.Module):
     def forward(
         self,
         image_embed,
-        *,
         diffusion_timesteps,
+        *,
         text_encodings,
         text_embed,
         mask = None,
```
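This hunk moves the keyword-only marker so that `diffusion_timesteps` becomes positional, matching how the prior calls the network as `self.net(image_embed_noisy, t, **text_cond)` in the next hunk. A hypothetical call site under the new signature; `prior_network` and all shapes here are illustrative (see the end-to-end sketch after the last hunk for how it might be constructed):

```python
import torch

# illustrative shapes only: batch of 4, CLIP latent dim 512, 256 text tokens
image_embed    = torch.randn(4, 512)
times          = torch.randint(0, 1000, (4,))   # one diffusion timestep per sample
text_embed     = torch.randn(4, 512)
text_encodings = torch.randn(4, 256, 512)
mask           = torch.ones(4, 256).bool()

# diffusion timesteps are now positional; the conditioning stays keyword-only
pred = prior_network(
    image_embed,
    times,
    text_encodings = text_encodings,
    text_embed = text_embed,
    mask = mask
)
```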
```diff
@@ -272,21 +272,169 @@ class DiffusionPrior(nn.Module):
 class DiffusionPrior(nn.Module):
     def __init__(
         self,
+        net,
         *,
-        clip
+        clip,
+        timesteps = 1000,
+        cond_prob_drop = 0.2,
+        loss_type = 'l1'
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
         freeze_model_and_make_eval_(clip)
 
-    def forward(
-        self,
-        *,
-        text,
-        image = None
-    ):
+        self.clip = clip   # kept around (frozen) to derive image / text embeddings below
+        self.net = net
+        self.image_embed_dim = clip.dim_latent
+        self.channels = clip.image_channels
+        self.image_size = clip.image_size
+        self.cond_prob_drop = cond_prob_drop
+
+        betas = cosine_beta_schedule(timesteps)
+
+        alphas = 1. - betas
+        alphas_cumprod = torch.cumprod(alphas, axis = 0)
+        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)  # prepend 1. so the posterior variance is 0 at t = 0
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.loss_type = loss_type
+
+        self.register_buffer('betas', betas)
+        self.register_buffer('alphas_cumprod', alphas_cumprod)
+        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+
+        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
+        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
+        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+
+        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+
+        self.register_buffer('posterior_variance', posterior_variance)
+
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+
+        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
+        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
+        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
+
+    def get_image_embed(self, image):
+        image_encoding = self.clip.visual_transformer(image)
+        image_cls = image_encoding[:, 0]
+        image_embed = self.clip.to_visual_latent(image_cls)
         return image_embed
+
+    def get_text_cond(self, text):
+        text_encodings = self.clip.text_transformer(text)
+        text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
+        text_embed = self.clip.to_text_latent(text_cls)
+        return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
+
+    def q_mean_variance(self, x_start, t):
+        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
+
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_recon, x_t = x, t = t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample_loop(self, shape, text_cond):
+        device = self.betas.device
+
+        b = shape[0]
+        img = torch.randn(shape, device = device)
+
+        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
+        return img
+
+    @torch.no_grad()
+    def sample(self, text):
+        batch_size = text.shape[0]
+        image_embed_dim = self.image_embed_dim
+
+        text_cond = self.get_text_cond(text)
+        return self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+
+    def q_sample(self, x_start, t, noise = None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    def p_losses(self, image_embed, t, text_cond, noise = None):
+        noise = default(noise, lambda: torch.randn_like(image_embed))
+
+        image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)
+
+        x_recon = self.net(
+            image_embed_noisy,
+            t,
+            cond_prob_drop = self.cond_prob_drop,
+            **text_cond
+        )
+
+        if self.loss_type == 'l1':
+            loss = F.l1_loss(noise, x_recon)
+        elif self.loss_type == 'l2':
+            loss = F.mse_loss(noise, x_recon)
+        else:
+            raise NotImplementedError()
+
+        return loss
+
+    def forward(self, text, image, *args, **kwargs):
+        b, device, img_size = image.shape[0], image.device, self.image_size
+        check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
+
+        times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
+        image_embed = self.get_image_embed(image)
+        text_cond = self.get_text_cond(text)
+
+        loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
+        return loss
 
 # decoder
 
 def Upsample(dim):
```
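For reference, the buffers registered above implement the standard DDPM closed forms from Ho et al. (2020): `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` parameterize the marginal of the forward process, while `posterior_mean_coef1`, `posterior_mean_coef2` and `posterior_variance` parameterize the true posterior used at sampling time:

$$q(x_t \mid x_0) = \mathcal{N}\big(\sqrt{\bar\alpha_t}\, x_0,\ (1 - \bar\alpha_t)\, I\big)$$

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(\tilde\mu_t,\ \tilde\beta_t I\big), \qquad \tilde\mu_t = \frac{\beta_t \sqrt{\bar\alpha_{t-1}}}{1 - \bar\alpha_t}\, x_0 + \frac{(1 - \bar\alpha_{t-1})\sqrt{\alpha_t}}{1 - \bar\alpha_t}\, x_t, \qquad \tilde\beta_t = \frac{(1 - \bar\alpha_{t-1})\, \beta_t}{1 - \bar\alpha_t}$$

where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s \le t} \alpha_s$ are exactly the `alphas` and `alphas_cumprod` computed in `__init__`.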
```diff
@@ -428,9 +576,9 @@ class Unet(nn.Module):
     def forward(
         self,
         x,
+        time,
         *,
         image_embed,
-        time,
         text_encodings = None,
         cond_prob_drop = 0.
     ):
```