diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 130c023..abf1072 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1870,13 +1870,14 @@ class Decoder(BaseGaussianDiffusion):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
+    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
         device = self.betas.device
 
         b = shape[0]
         img = torch.randn(shape, device = device)
 
-        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+        if not is_latent_diffusion:
+            lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
             img = self.p_sample(
@@ -1896,13 +1897,14 @@ class Decoder(BaseGaussianDiffusion):
         unnormalize_img = unnormalize_zero_to_one(img)
         return unnormalize_img
 
-    def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
+    def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         # normalize to [-1, 1]
 
-        x_start = normalize_neg_one_to_one(x_start)
-        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+        if not is_latent_diffusion:
+            x_start = normalize_neg_one_to_one(x_start)
+            lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
 
         # get x_t
 
@@ -2016,7 +2018,8 @@ class Decoder(BaseGaussianDiffusion):
                 predict_x_start = predict_x_start,
                 learned_variance = learned_variance,
                 clip_denoised = not is_latent_diffusion,
-                lowres_cond_img = lowres_cond_img
+                lowres_cond_img = lowres_cond_img,
+                is_latent_diffusion = is_latent_diffusion
             )
 
             img = vae.decode(img)
@@ -2075,12 +2078,14 @@ class Decoder(BaseGaussianDiffusion):
                 image = aug(image)
                 lowres_cond_img = aug(lowres_cond_img, params = aug._params)
 
+        is_latent_diffusion = not isinstance(vae, NullVQGanVAE)
+
         vae.eval()
         with torch.no_grad():
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
 
-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
+        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
 
 # main class
 
diff --git a/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py b/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py
new file mode 100644
index 0000000..1418c94
--- /dev/null
+++ b/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py
@@ -0,0 +1,59 @@
+from pathlib import Path
+
+import torch
+from torch.utils import data
+from torchvision import transforms, utils
+
+from PIL import Image
+
+# helpers functions
+
+def cycle(dl):
+    while True:
+        for data in dl:
+            yield data
+
+# dataset and dataloader
+
+class Dataset(data.Dataset):
+    def __init__(
+        self,
+        folder,
+        image_size,
+        exts = ['jpg', 'jpeg', 'png']
+    ):
+        super().__init__()
+        self.folder = folder
+        self.image_size = image_size
+        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
+
+        self.transform = transforms.Compose([
+            transforms.Resize(image_size),
+            transforms.RandomHorizontalFlip(),
+            transforms.CenterCrop(image_size),
+            transforms.ToTensor()
+        ])
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        path = self.paths[index]
+        img = Image.open(path)
+        return self.transform(img)
+
+def get_images_dataloader(
+    folder,
+    *,
+    batch_size,
+    image_size,
+    shuffle = True,
+    cycle_dl = True,
+    pin_memory = True
+):
+    ds = Dataset(folder, image_size)
+    dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
+
+    if cycle_dl:
+        dl = cycle(dl)
+    return dl
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index e86faff..594e389 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -179,8 +179,8 @@ class EMA(nn.Module):
         self.online_model = model
         self.ema_model = copy.deepcopy(model)
 
-        self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
         self.update_every = update_every
+        self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
 
         self.register_buffer('initted', torch.Tensor([False]))
         self.register_buffer('step', torch.tensor([0.]))
@@ -189,6 +189,9 @@ class EMA(nn.Module):
         device = self.initted.device
         self.ema_model.to(device)
 
+    def copy_params_from_model_to_ema(self):
+        self.ema_model.state_dict(self.online_model.state_dict())
+
     def update(self):
         self.step += 1
 
@@ -196,7 +199,7 @@ class EMA(nn.Module):
             return
 
         if not self.initted:
-            self.ema_model.state_dict(self.online_model.state_dict())
+            self.copy_params_from_model_to_ema()
             self.initted.data.copy_(torch.Tensor([True]))
 
         self.update_moving_average(self.ema_model, self.online_model)
@@ -405,6 +408,9 @@ class DecoderTrainer(nn.Module):
     @torch.no_grad()
     @cast_torch_tensor
     def sample(self, *args, **kwargs):
+        if kwargs.pop('use_non_ema', False):
+            return self.decoder.sample(*args, **kwargs)
+
         if self.use_ema:
             trainable_unets = self.decoder.unets
             self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
diff --git a/setup.py b/setup.py
index 6b7b924..930e90e 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.42',
+  version = '0.2.43',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
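A minimal usage sketch of the new `get_images_dataloader` helper added above (not part of the diff; the image folder path and sizes are placeholder assumptions):

```python
# assuming dalle2_pytorch.dataloaders is importable as a package
from dalle2_pytorch.dataloaders.simple_image_only_dataloader import get_images_dataloader

# build a cycled, shuffled dataloader over a folder of images
# './images' is a placeholder path - point it at any folder of jpg/jpeg/png files
dataloader = get_images_dataloader(
    './images',
    batch_size = 4,
    image_size = 256
)

# cycle() wraps the DataLoader in an infinite generator, so next() always yields a batch
images = next(dataloader)  # (4, 3, 256, 256) for RGB inputs, values in [0, 1]
```

On the trainer side, the new `use_non_ema` flag is popped out of the sampling kwargs, so a call like `decoder_trainer.sample(image_embed = image_embed, use_non_ema = True)` (instance name hypothetical) bypasses the exponential moving averaged unets and samples from the online ones instead.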