Compare commits

...

3 Commits

4 changed files with 95 additions and 16 deletions

View File

@@ -61,6 +61,9 @@ def default(val, d):
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def module_device(module):
return next(module.parameters()).device
@contextmanager
def null_context(*args, **kwargs):
yield
@@ -1817,7 +1820,7 @@ class Decoder(BaseGaussianDiffusion):
self.cuda()
devices = [next(unet.parameters()).device for unet in self.unets]
devices = [module_device(unet) for unet in self.unets]
self.unets.cpu()
unet.cuda()
@@ -1867,13 +1870,14 @@ class Decoder(BaseGaussianDiffusion):
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device = device)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
if not is_latent_diffusion:
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(
@@ -1893,13 +1897,14 @@ class Decoder(BaseGaussianDiffusion):
unnormalize_img = unnormalize_zero_to_one(img)
return unnormalize_img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
x_start = normalize_neg_one_to_one(x_start)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
if not is_latent_diffusion:
x_start = normalize_neg_one_to_one(x_start)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
# get x_t
@@ -1965,6 +1970,8 @@ class Decoder(BaseGaussianDiffusion):
self,
image_embed = None,
text = None,
text_mask = None,
text_encodings = None,
batch_size = 1,
cond_scale = 1.,
stop_at_unet_number = None
@@ -1974,8 +1981,8 @@ class Decoder(BaseGaussianDiffusion):
if not self.unconditional:
batch_size = image_embed.shape[0]
text_encodings = text_mask = None
if exists(text):
if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip)
_, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
@@ -2011,7 +2018,8 @@ class Decoder(BaseGaussianDiffusion):
predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img
lowres_cond_img = lowres_cond_img,
is_latent_diffusion = is_latent_diffusion
)
img = vae.decode(img)
@@ -2027,6 +2035,7 @@ class Decoder(BaseGaussianDiffusion):
text = None,
image_embed = None,
text_encodings = None,
text_mask = None,
unet_number = None
):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
@@ -2051,7 +2060,6 @@ class Decoder(BaseGaussianDiffusion):
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
image_embed, _ = self.clip.embed_image(image)
text_encodings = text_mask = None
if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
_, text_encodings, text_mask = self.clip.embed_text(text)
@@ -2070,12 +2078,14 @@ class Decoder(BaseGaussianDiffusion):
image = aug(image)
lowres_cond_img = aug(lowres_cond_img, params = aug._params)
is_latent_diffusion = not isinstance(vae, NullVQGanVAE)
vae.eval()
with torch.no_grad():
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
# main class
@@ -2107,7 +2117,7 @@ class DALLE2(nn.Module):
prior_cond_scale = 1.,
return_pil_images = False
):
device = next(self.parameters()).device
device = module_device(self)
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
if isinstance(text, str) or is_list_str(text):

View File

@@ -0,0 +1,59 @@
from pathlib import Path
import torch
from torch.utils import data
from torchvision import transforms, utils
from PIL import Image
# helpers functions
def cycle(dl):
while True:
for data in dl:
yield data
# dataset and dataloader
class Dataset(data.Dataset):
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png']
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
self.transform = transforms.Compose([
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
def get_images_dataloader(
folder,
*,
batch_size,
image_size,
shuffle = True,
cycle_dl = True,
pin_memory = True
):
ds = Dataset(folder, image_size)
dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
if cycle_dl:
dl = cycle(dl)
return dl

View File

@@ -179,8 +179,8 @@ class EMA(nn.Module):
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
self.update_every = update_every
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
@@ -189,14 +189,21 @@ class EMA(nn.Module):
device = self.initted.device
self.ema_model.to(device)
def copy_params_from_model_to_ema(self):
self.ema_model.state_dict(self.online_model.state_dict())
def update(self):
self.step += 1
if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
if (self.step % self.update_every) != 0:
return
if self.step <= self.update_after_step:
self.copy_params_from_model_to_ema()
return
if not self.initted:
self.ema_model.state_dict(self.online_model.state_dict())
self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
@@ -405,6 +412,9 @@ class DecoderTrainer(nn.Module):
@torch.no_grad()
@cast_torch_tensor
def sample(self, *args, **kwargs):
if kwargs.pop('use_non_ema', False):
return self.decoder.sample(*args, **kwargs)
if self.use_ema:
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.40',
version = '0.2.44',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',