From cbaadb693116b214cc29d0569ed0c0a5b9178ab7 Mon Sep 17 00:00:00 2001
From: Aidan
Date: Fri, 29 Jul 2022 16:57:27 +0000
Subject: [PATCH] Fixed issues with clip and deepspeed fp16

Also more general compatibility fixes
---
 dalle2_pytorch/train_configs.py |  2 +-
 dalle2_pytorch/trainer.py       |  2 +-
 train_decoder.py                | 60 +++++++++++++++++++++------------
 3 files changed, 40 insertions(+), 24 deletions(-)

diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index ae6407f..6c5bee8 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -241,7 +241,7 @@ class DecoderConfig(BaseModel):
     clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
-    sample_timesteps: Optional[SingularOrIterable[int]] = None
+    sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
     loss_type: str = 'l2'
     beta_schedule: ListOrTuple[str] = None # None means all cosine
     learned_variance: SingularOrIterable[bool] = True
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index a6b6bbc..9b756e0 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -519,7 +519,7 @@ class DecoderTrainer(nn.Module):
             clip = decoder.clip
             clip.to(precision_type)
 
-        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+        decoder, train_dataloader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders['train'], *optimizers))
 
         self.decoder = decoder
 
diff --git a/train_decoder.py b/train_decoder.py
index 1da9524..41f753d 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -134,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
             break
     return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
 
-def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
+def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
     """
     Takes example data and generates images from the embeddings
     Returns three lists: real images, generated images, and captions
@@ -144,7 +144,9 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
     if img_embeddings[0] is None:
         # Generate image embeddings from clip
         imgs_tensor = torch.stack(real_images)
-        img_embeddings, *_ = trainer.embed_image(imgs_tensor)
+        assert clip is not None, "clip is None, but img_embeddings is None"
+        imgs_tensor = imgs_tensor.to(device=device)
+        img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
         sample_params["image_embed"] = img_embeddings
     else:
         # Then we are using precomputed image embeddings
@@ -153,8 +155,10 @@
     if condition_on_text_encodings:
         if text_embeddings[0] is None:
             # Generate text embeddings from text
+            assert clip is not None, "clip is None, but text_embeddings is None"
             tokenized_texts = tokenize(txts, truncate=True)
-            sample_params["text"] = tokenized_texts
+            text_embed, text_encodings = clip.embed_text(tokenized_texts)
+            sample_params["text_encodings"] = text_encodings
         else:
             # Then we are using precomputed text embeddings
             text_embeddings = torch.stack(text_embeddings)
@@ -166,7 +170,7 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
         sample_params["image"] = torch.stack(real_images)
     if device is not None:
         sample_params["_device"] = device
-    samples = trainer.sample(**sample_params)
+    samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16
     generated_images = list(samples)
     captions = [text_prepend + txt for txt in txts]
     if match_image_size:
@@ -174,15 +178,15 @@
         real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
     return real_images, generated_images, captions
 
-def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
+def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
     """
     Generates samples and uses torchvision to put them in a side by side grid for easy viewing
     """
-    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
+    real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions
 
-def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
@@ -192,7 +196,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
     if len(examples) == 0:
         print("No data to evaluate. Check that your dataloader has shards.")
         return metrics
-    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
+    real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
     # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -265,6 +269,7 @@ def train(
     accelerator: Accelerator,
     tracker: Tracker,
     inference_device,
+    clip=None,
    evaluate_config=None,
     epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
     validation_samples = None,
@@ -371,15 +376,19 @@ def train(
                        forward_params['image_embed'] = img_emb
                    else:
                        # Forward pass automatically generates embedding
-                        pass
+                        assert clip is not None
+                        img_embed, img_encoding = clip.embed_image(img)
+                        forward_params['image_embed'] = img_embed
                    if condition_on_text_encodings:
                        if has_text_embedding:
                            forward_params['text_encodings'] = text_emb
                        else:
                            # Then we need to pass the text instead
-                            tokenized_texts = tokenize(txt, truncate=True)
+                            assert clip is not None
+                            tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
                            assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
-                            forward_params['text'] = tokenized_texts
+                            text_embed, text_encodings = clip.embed_text(tokenized_texts)
+                            forward_params['text_encodings'] = text_encodings
                    loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
                    trainer.update(unet_number=unet)
                    unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -419,7 +428,7 @@ def train(
                    save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
                    if exists(n_sample_images) and n_sample_images > 0:
                        trainer.eval()
-                        train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+                        train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
                        tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

                if epoch_samples is not None and sample >= epoch_samples:
@@ -462,15 +471,19 @@ def train(
                        forward_params['image_embed'] = img_emb.float()
                    else:
                        # Forward pass automatically generates embedding
-                        pass
+                        assert clip is not None
+                        img_embed, img_encoding = clip.embed_image(img)
+                        forward_params['image_embed'] = img_embed
                    if condition_on_text_encodings:
                        if has_text_embedding:
                            forward_params['text_encodings'] = text_emb.float()
                        else:
                            # Then we need to pass the text instead
+                            assert clip is not None
                            tokenized_texts = tokenize(txt, truncate=True)
                            assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
-                            forward_params['text'] = tokenized_texts
+                            text_embed, text_encodings = clip.embed_text(tokenized_texts)
+                            forward_params['text_encodings'] = text_encodings
                    loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
                    average_val_loss_tensor[0, unet-1] += loss

@@ -498,7 +511,7 @@ def train(
        if next_task == 'eval':
            if exists(evaluate_config):
                accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
+                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
                if is_master:
                    tracker.log(evaluation, step=step())
            next_task = 'sample'
@@ -509,8 +522,8 @@ def train(
            # Generate examples and save the model if we are the master
            # Generate sample images
            print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
-            test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
-            train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+            test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
+            train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
            tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
            tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

@@ -532,6 +545,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
        "NumProcesses": accelerator.num_processes,
        "MixedPrecision": accelerator.mixed_precision
    }
+    accelerator.wait_for_everyone() # If nodes arrive at this point at different times, they might try to autoresume the current run, which makes no sense and will cause errors
    tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
    tracker.save_config(config_path, config_name='decoder_config.json')
    tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
@@ -555,10 +569,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
    # If we are in deepspeed fp16 mode, we must ensure learned variance is off
    if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
        raise ValueError("DeepSpeed fp16 mode does not support learned variance")
-
-    if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
-        # This is an invalid configuration until we figure out how to handle this
-        raise ValueError("DeepSpeed does not support multi-node distributed training")

    # Set up data
    all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
@@ -579,6 +589,11 @@ def initialize_training(config: TrainDecoderConfig, config_path):
        seed = config.seed,
    )

+    # If clip is in the model, we need to remove it for compatibility with deepspeed
+    clip = None
+    if config.decoder.clip is not None:
+        clip = config.decoder.clip.create() # We still keep it for use during training, just not inside the decoder, since that causes issues
+        config.decoder.clip = None
    # Create the decoder model and print basic info
    decoder = config.decoder.create()
    get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
@@ -590,7 +605,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
    has_text_embeddings = config.data.text_embeddings_url is not None
    conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
-    has_clip_model = config.decoder.clip is not None
+    has_clip_model = clip is not None

    data_source_string = ""

    if has_img_embeddings:
@@ -615,6 +630,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
        accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")

    train(dataloaders, decoder, accelerator,
+        clip=clip,
        tracker=tracker,
        inference_device=accelerator.device,
        evaluate_config=config.evaluate,
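
The pattern the hunks above introduce is to instantiate CLIP once from the config, keep it outside the decoder (so it stays out of the module DeepSpeed prepares and casts), and pass it explicitly to train(), generate_samples() and evaluate_trainer(), which fall back to clip.embed_image / clip.embed_text whenever precomputed embeddings are missing. Below is a minimal sketch of that fallback, assuming a loaded TrainDecoderConfig and an image batch already on the training device; the helper names are illustrative only and are not part of the patch.

    # Illustrative sketch only; mirrors the fallback wiring added in train_decoder.py above.
    def build_clip_outside_decoder(config):
        # Create the adapter from the config, then blank it out of the decoder config
        # so the decoder (and therefore accelerator.prepare / DeepSpeed) never owns it.
        clip = None
        if config.decoder.clip is not None:
            clip = config.decoder.clip.create()
            config.decoder.clip = None
        return clip

    def image_forward_params(images, img_emb, clip):
        # Prefer precomputed embeddings; otherwise embed with the external CLIP adapter,
        # which returns (image_embed, image_encodings) as unpacked in the patch.
        forward_params = {}
        if img_emb is not None:
            forward_params['image_embed'] = img_emb
        else:
            assert clip is not None, "no precomputed image embeddings and no clip model"
            image_embed, _ = clip.embed_image(images)
            forward_params['image_embed'] = image_embed
        return forward_params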