Compare commits

..

9 Commits

3 changed files with 60 additions and 48 deletions

View File

@@ -7,6 +7,7 @@ from contextlib import contextmanager
import torch
import torch.nn.functional as F
from torch import nn, einsum
import torchvision.transforms as T
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
@@ -222,7 +223,18 @@ class BaseGaussianDiffusion(nn.Module):
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
if loss_type == 'l1':
loss_fn = F.l1_loss
elif loss_type == 'l2':
loss_fn = F.mse_loss
elif loss_type == 'huber':
loss_fn = F.smooth_l1_loss
else:
raise NotImplementedError()
self.loss_type = loss_type
self.loss_fn = loss_fn
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
@@ -646,9 +658,12 @@ class DiffusionPrior(BaseGaussianDiffusion):
)
if exists(clip):
assert isinstance(clip, CLIP)
if isinstance(clip, CLIP):
clip = XClipAdapter(clip)
assert isinstance(clip, BaseClipAdapter)
freeze_model_and_make_eval_(clip)
self.clip = XClipAdapter(clip)
self.clip = clip
else:
assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
self.clip = None
@@ -699,29 +714,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
def p_losses(self, image_embed, t, text_cond, noise = None):
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)
image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
x_recon = self.net(
pred = self.net(
image_embed_noisy,
t,
times,
cond_drop_prob = self.cond_drop_prob,
**text_cond
)
to_predict = noise if not self.predict_x_start else image_embed
if self.loss_type == 'l1':
loss = F.l1_loss(to_predict, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(to_predict, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(to_predict, x_recon)
else:
raise NotImplementedError()
target = noise if not self.predict_x_start else image_embed
loss = self.loss_fn(pred, target)
return loss
@torch.no_grad()
@@ -736,11 +743,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
text_embed, text_encodings = self.clip.embed_text(text)
text_cond = dict(
text_embed = text_embed,
text_encodings = text_encodings,
mask = text != 0
)
text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']
@@ -780,11 +786,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
text_embed, text_encodings = self.clip.embed_text(text)
text_mask = text != 0
text_cond = dict(
text_embed = text_embed,
text_encodings = text_encodings,
mask = text_mask
)
text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings:
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
# timestep conditioning from ddpm
@@ -793,8 +799,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# calculate forward loss
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
return loss
return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
# decoder
@@ -1064,13 +1069,14 @@ class Unet(nn.Module):
self,
*,
lowres_cond,
channels
channels,
cond_on_image_embeds
):
if lowres_cond == self.lowres_cond and channels == self.channels:
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
return self.__class__(**updated_kwargs)
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
return self.__class__(**{**self._locals, **updated_kwargs})
def forward_with_cond_scale(
self,
@@ -1207,7 +1213,7 @@ class LowresConditioner(nn.Module):
target_image_size = cast_tuple(target_image_size, 2)
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
if self.training:
# when training, blur the low resolution conditional image
@@ -1249,6 +1255,8 @@ class Decoder(BaseGaussianDiffusion):
clip = XClipAdapter(clip)
freeze_model_and_make_eval_(clip)
assert isinstance(clip, BaseClipAdapter)
self.clip = clip
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
@@ -1275,6 +1283,7 @@ class Decoder(BaseGaussianDiffusion):
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
cond_on_image_embeds = is_first,
channels = unet_channels
)
@@ -1382,14 +1391,14 @@ class Decoder(BaseGaussianDiffusion):
return img
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
x_recon = unet(
pred = unet(
x_noisy,
t,
times,
image_embed = image_embed,
text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img,
@@ -1398,15 +1407,7 @@ class Decoder(BaseGaussianDiffusion):
target = noise if not predict_x_start else x_start
if self.loss_type == 'l1':
loss = F.l1_loss(target, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(target, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(target, x_recon)
else:
raise NotImplementedError()
loss = self.loss_fn(pred, target)
return loss
@torch.no_grad()
@@ -1419,6 +1420,7 @@ class Decoder(BaseGaussianDiffusion):
_, text_encodings = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
img = None
@@ -1486,6 +1488,7 @@ class Decoder(BaseGaussianDiffusion):
_, text_encodings = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)
@@ -1518,12 +1521,15 @@ class DALLE2(nn.Module):
self.prior_num_samples = prior_num_samples
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
self.to_pil = T.ToPILImage()
@torch.no_grad()
@eval_decorator
def forward(
self,
text,
cond_scale = 1.
cond_scale = 1.,
return_pil_images = False
):
device = next(self.parameters()).device
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
@@ -1537,7 +1543,11 @@ class DALLE2(nn.Module):
text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
if return_pil_images:
images = list(map(self.to_pil, images.unbind(dim = 0)))
if one_text:
return images[0]
return images

View File

@@ -545,6 +545,7 @@ class VQGanVAE(nn.Module):
l2_recon_loss = False,
use_hinge_loss = True,
vgg = None,
vq_codebook_dim = 256,
vq_codebook_size = 512,
vq_decay = 0.8,
vq_commitment_weight = 1.,
@@ -579,6 +580,7 @@ class VQGanVAE(nn.Module):
self.vq = VQ(
dim = self.enc_dec.encoded_dim,
codebook_dim = vq_codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.57',
version = '0.0.65',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',