diff --git a/configs/README.md b/configs/README.md
index d473495..be3cb65 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -69,6 +69,7 @@ Settings for controlling the training hyperparameters.
 | `wd` | No | `0.01` | The weight decay. |
 | `max_grad_norm`| No | `0.5` | The grad norm clipping. |
 | `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
+| `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. |
 | `device` | No | `cuda:0` | The device to train on. |
 | `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
 | `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index b23a652..b1f202f 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -75,6 +75,8 @@ def cast_tuple(val, length = None, validate = True):
     return out
 
 def module_device(module):
+    if isinstance(module, nn.Identity):
+        return 'cpu' # It doesn't matter
     return next(module.parameters()).device
 
 def zero_init_(m):
@@ -2326,7 +2328,7 @@ class Decoder(nn.Module):
 
     @property
     def condition_on_text_encodings(self):
-        return any([unet.cond_on_text_encodings for unet in self.unets])
+        return any([unet.cond_on_text_encodings for unet in self.unets if isinstance(unet, Unet)])
 
     def get_unet(self, unet_number):
         assert 0 < unet_number <= self.num_unets
@@ -2646,11 +2648,13 @@ class Decoder(nn.Module):
     @eval_decorator
     def sample(
         self,
+        image = None,
         image_embed = None,
         text = None,
         text_encodings = None,
         batch_size = 1,
         cond_scale = 1.,
+        start_at_unet_number = 1,
         stop_at_unet_number = None,
         distributed = False,
         inpaint_image = None,
@@ -2671,14 +2675,22 @@
         assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'
 
         img = None
+        if start_at_unet_number > 1:
+            # Then we are not generating the first image and one must have been passed in
+            assert exists(image), 'image must be passed in if starting at unet number > 1'
+            assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
+            prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
+            img = resize_image_to(image, prev_unet_output_size, nearest = True)
 
         is_cuda = next(self.parameters()).is_cuda
 
         num_unets = self.num_unets
         cond_scale = cast_tuple(cond_scale, num_unets)
 
         for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
+            if unet_number < start_at_unet_number:
+                continue # It's the easiest way to do it
 
-            context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
+            context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
 
             with context:
                 # prepare low resolution conditioning for upsamplers
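For orientation on the `Decoder.sample` changes above, a minimal sketch of resuming sampling at an upsampler, assuming a trained two-unet `decoder`, a low-resolution `low_res_image`, and a precomputed `image_embed` are on hand; the variable names are illustrative and not part of the diff.

```python
# Sketch only: run just the second unet of a two-unet decoder on an existing image.
upscaled = decoder.sample(
    image = low_res_image,                  # required whenever start_at_unet_number > 1
    image_embed = image_embed,
    batch_size = low_res_image.shape[0],    # must match the batch size of `image`
    start_at_unet_number = 2,               # skip the base unet entirely
    stop_at_unet_number = 2,
    cond_scale = [1.0, 3.0],                # cond_scale may now be given per unet
)
```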
diff --git a/dalle2_pytorch/trackers.py b/dalle2_pytorch/trackers.py
index 057fbb9..13834c9 100644
--- a/dalle2_pytorch/trackers.py
+++ b/dalle2_pytorch/trackers.py
@@ -530,11 +530,14 @@ class Tracker:
             prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
             prior: DiffusionPrior = trainer.unwrap_model(prior)
             # Remove CLIP if it is part of the model
+            original_clip = prior.clip
             prior.clip = None
             model_state_dict = prior.state_dict()
+            prior.clip = original_clip
         elif isinstance(trainer, DecoderTrainer):
             decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
             # Remove CLIP if it is part of the model
+            original_clip = decoder.clip
             decoder.clip = None
             if trainer.use_ema:
                 trainable_unets = decoder.unets
@@ -543,6 +546,7 @@
                 decoder.unets = trainable_unets # Swap back
             else:
                 model_state_dict = decoder.state_dict()
+            decoder.clip = original_clip
         else:
             raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
         state_dict = {
diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index dd30e46..4a0c003 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -306,9 +306,11 @@ class DecoderTrainConfig(BaseModel):
     max_grad_norm: SingularOrIterable(float) = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
+    cond_scale: Union[float, List[float]] = 1.0
     device: str = 'cuda:0'
     epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
     validation_samples: int = None # Same as above but for validation.
+    save_immediately: bool = False
     use_ema: bool = True
     ema_beta: float = 0.999
     amp: bool = False
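A hedged sketch of how the two new `DecoderTrainConfig` fields might look in the `train` section of a decoder training config; the surrounding keys and values are placeholders rather than settings taken from this diff.

```python
# Sketch only: the `train` portion of a decoder training config using the new fields.
train_section = {
    "lr": 1e-4,
    "cond_scale": [1.0, 3.0],              # one conditioning scale per unet, or a single float
    "save_immediately": False,             # when True, snapshot on the very first training iteration
    "unet_training_mask": [False, True],   # train (and sample through) only the upsampler
}
```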
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 7fe0a02..e0b3b26 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -498,23 +498,27 @@ class DecoderTrainer(nn.Module):
         warmup_schedulers = []
 
         for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
-            optimizer = get_optimizer(
-                unet.parameters(),
-                lr = unet_lr,
-                wd = unet_wd,
-                eps = unet_eps,
-                group_wd_params = group_wd_params,
-                **kwargs
-            )
+            if isinstance(unet, nn.Identity):
+                optimizers.append(None)
+                schedulers.append(None)
+                warmup_schedulers.append(None)
+            else:
+                optimizer = get_optimizer(
+                    unet.parameters(),
+                    lr = unet_lr,
+                    wd = unet_wd,
+                    eps = unet_eps,
+                    group_wd_params = group_wd_params,
+                    **kwargs
+                )
 
-            optimizers.append(optimizer)
+                optimizers.append(optimizer)
+                scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
 
-            scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+                warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
+                warmup_schedulers.append(warmup_scheduler)
 
-            warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
-            warmup_schedulers.append(warmup_scheduler)
-
-            schedulers.append(scheduler)
+                schedulers.append(scheduler)
 
             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))
@@ -590,7 +594,8 @@ class DecoderTrainer(nn.Module):
         for ind in range(0, self.num_unets):
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
-            save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
+            state_dict = optimizer.state_dict() if optimizer is not None else None
+            save_obj = {**save_obj, optimizer_key: state_dict}
 
         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -612,8 +617,8 @@ class DecoderTrainer(nn.Module):
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
             warmup_scheduler = self.warmup_schedulers[ind]
-
-            self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
+            if optimizer is not None:
+                optimizer.load_state_dict(loaded_obj[optimizer_key])
 
             if exists(warmup_scheduler):
                 warmup_scheduler.last_step = last_step
@@ -714,23 +719,32 @@ class DecoderTrainer(nn.Module):
         *args,
         unet_number = None,
         max_batch_size = None,
+        return_lowres_cond_image=False,
         **kwargs
     ):
         unet_number = self.validate_and_return_unet_number(unet_number)
 
         total_loss = 0.
-
-
-        using_amp = self.accelerator.mixed_precision != 'no'
-
+        cond_images = []
         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with self.accelerator.autocast():
-                loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
+                loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
+                # loss_obj may be a tuple with loss and cond_image
+                if return_lowres_cond_image:
+                    loss, cond_image = loss_obj
+                else:
+                    loss = loss_obj
+                    cond_image = None
                 loss = loss * chunk_size_frac
+                if cond_image is not None:
+                    cond_images.append(cond_image)
 
             total_loss += loss.item()
 
             if self.training:
                 self.accelerator.backward(loss)
 
-        return total_loss
+        if return_lowres_cond_image:
+            return total_loss, torch.stack(cond_images)
+        else:
+            return total_loss
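A hedged sketch of the trainer-side behaviour these hunks enable: unets swapped out for `nn.Identity` get `None` optimizers and are skipped on save/load, while `forward` can optionally also return the stacked low-resolution conditioning images. The tensor names below are illustrative.

```python
# Sketch only: one training step against the upsampler (unet 2) with the base unet frozen
# out as nn.Identity; assumes `trainer` is a DecoderTrainer and the batch tensors exist.
loss, lowres_cond_images = trainer.forward(
    images,
    image_embed = image_embeds,
    unet_number = 2,
    return_lowres_cond_image = True,   # also return the conditioning images fed to the upsampler
)
trainer.update(unet_number = 2)
```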
diff --git a/train_decoder.py b/train_decoder.py
index 3c41df7..1da9524 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 from typing import List
+from datetime import timedelta
 from dalle2_pytorch.trainer import DecoderTrainer
 from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
@@ -11,11 +12,12 @@ from clip import tokenize
 import torchvision
 import torch
+from torch import nn
 from torchmetrics.image.fid import FrechetInceptionDistance
 from torchmetrics.image.inception import InceptionScore
 from torchmetrics.image.kid import KernelInceptionDistance
 from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
-from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
 from accelerate.utils import dataclasses as accelerate_dataclasses
 import webdataset as wds
 import click
@@ -132,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
             break
     return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
 
-def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
+def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
     """
     Takes example data and generates images from the embeddings
     Returns three lists: real images, generated images, and captions
@@ -157,6 +159,13 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
         # Then we are using precomputed text embeddings
         text_embeddings = torch.stack(text_embeddings)
         sample_params["text_encodings"] = text_embeddings
+    sample_params["start_at_unet_number"] = start_unet
+    sample_params["stop_at_unet_number"] = end_unet
+    if start_unet > 1:
+        # If we are only training upsamplers
+        sample_params["image"] = torch.stack(real_images)
+    if device is not None:
+        sample_params["_device"] = device
     samples = trainer.sample(**sample_params)
     generated_images = list(samples)
     captions = [text_prepend + txt for txt in txts]
@@ -165,15 +174,15 @@
         real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
     return real_images, generated_images, captions
 
-def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
+def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
     """
     Generates samples and uses torchvision to put them in a side by side grid for easy viewing
     """
-    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
+    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions
 
-def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
@@ -183,7 +192,7 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
     if len(examples) == 0:
         print("No data to evaluate. Check that your dataloader has shards.")
         return metrics
-    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
+    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
     # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -259,11 +268,13 @@ def train(
     evaluate_config=None,
     epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
     validation_samples = None,
+    save_immediately=False,
     epochs = 20,
     n_sample_images = 5,
     save_every_n_samples = 100000,
     unet_training_mask=None,
     condition_on_text_encodings=False,
+    cond_scale=1.0,
     **kwargs
 ):
     """
@@ -271,6 +282,21 @@ def train(
     """
     is_master = accelerator.process_index == 0
 
+    if not exists(unet_training_mask):
+        # Then the unet mask should be true for all unets in the decoder
+        unet_training_mask = [True] * len(decoder.unets)
+    assert len(unet_training_mask) == len(decoder.unets), f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
+    trainable_unet_numbers = [i+1 for i, trainable in enumerate(unet_training_mask) if trainable]
+    first_trainable_unet = trainable_unet_numbers[0]
+    last_trainable_unet = trainable_unet_numbers[-1]
+    def move_unets(unet_training_mask):
+        for i in range(len(decoder.unets)):
+            if not unet_training_mask[i]:
+                # Replace the unet from the module list with a nn.Identity(). This training script never uses unets that aren't being trained so this is fine.
+                decoder.unets[i] = nn.Identity().to(inference_device)
+    # Remove non-trainable unets
+    move_unets(unet_training_mask)
+
     trainer = DecoderTrainer(
         decoder=decoder,
         accelerator=accelerator,
@@ -285,6 +311,7 @@
     sample = 0
     samples_seen = 0
     val_sample = 0
+    step = lambda: int(trainer.num_steps_taken(unet_number=first_trainable_unet))
 
     if tracker.can_recall:
         start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
@@ -296,13 +323,6 @@
     accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
     trainer.to(device=inference_device)
 
-    if not exists(unet_training_mask):
-        # Then the unet mask should be true for all unets in the decoder
-        unet_training_mask = [True] * trainer.num_unets
-    first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
-    step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
-    assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
-
     accelerator.print(print_ribbon("Generating Example Data", repeat=40))
     accelerator.print("This can take a while to load the shard lists...")
     if is_master:
@@ -360,7 +380,7 @@
                     tokenized_texts = tokenize(txt, truncate=True)
                     assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
                     forward_params['text'] = tokenized_texts
-                loss = trainer.forward(img, **forward_params, unet_number=unet)
+                loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
                 trainer.update(unet_number=unet)
                 unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
 
@@ -373,10 +393,10 @@
                 unet_all_losses = accelerator.gather(unet_losses_tensor)
                 mask = unet_all_losses != 0
                 unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
-                loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
+                loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if unet_training_mask[index] }
 
                 # gather decay rate on each UNet
-                ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
+                ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets) if unet_training_mask[index]}
 
                 log_data = {
                     "Epoch": epoch,
@@ -391,7 +411,7 @@
             if is_master:
                 tracker.log(log_data, step=step())
 
-            if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
+            if is_master and (last_snapshot + save_every_n_samples < sample or (save_immediately and i == 0)): # This will miss by some amount every time, but it's not a big deal... I hope
                 # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
                 print("Saving snapshot")
                 last_snapshot = sample
@@ -399,7 +419,7 @@
                 save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
                 if exists(n_sample_images) and n_sample_images > 0:
                     trainer.eval()
-                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
+                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
                     tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
 
             if epoch_samples is not None and sample >= epoch_samples:
@@ -449,8 +469,9 @@
                     else:
                         # Then we need to pass the text instead
                         tokenized_texts = tokenize(txt, truncate=True)
+                        assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
                         forward_params['text'] = tokenized_texts
-                    loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
+                    loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
                     average_val_loss_tensor[0, unet-1] += loss
 
             if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
@@ -477,7 +498,7 @@
         if next_task == 'eval':
            if exists(evaluate_config):
                 accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
+                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
                 if is_master:
                     tracker.log(evaluation, step=step())
             next_task = 'sample'
@@ -488,15 +509,15 @@
             # Generate examples and save the model if we are the master
             # Generate sample images
             print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
-            test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
-            train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
+            test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
+            train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
             tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
             tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
 
             print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
             is_best = False
             if all_average_val_losses is not None:
-                average_loss = all_average_val_losses.mean(dim=0).item()
+                average_loss = all_average_val_losses.mean(dim=0).sum() / sum(unet_training_mask)
                 if len(validation_losses) == 0 or average_loss < min(validation_losses):
                     is_best = True
                 validation_losses.append(average_loss)
@@ -522,7 +543,8 @@ def initialize_training(config: TrainDecoderConfig, config_path):
 
     # Set up accelerator for configurable distributed training
     ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
-    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
+    init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
+    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
 
     if accelerator.num_processes > 1:
         # We are using distributed training and want to immediately ensure all can connect
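Finally, a minimal sketch of the accelerator setup introduced in the last hunk, with the process-group timeout raised to one hour (the `60*60` seconds in the diff); a longer timeout helps when some ranks spend a long time listing webdataset shards before the first collective call.

```python
# Sketch only: give distributed workers up to an hour to join before process-group
# initialization times out.
from datetime import timedelta
from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60 * 60))
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
```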