From f5760bdb92fb8fe9f53aca4a87b3e4b9ef8cdbaa Mon Sep 17 00:00:00 2001 From: Aidan Dempster Date: Sat, 25 Jun 2022 22:05:20 -0400 Subject: [PATCH] Add data flexibility to decoder trainer (#165) * Added the ability to train decoder with text embeddings * Added the ability to train using on the fly generated embeddings with clip * Clip now generates embeddings for whatever is not precomputed --- dalle2_pytorch/dataloaders/decoder_loader.py | 74 ++++++--- dalle2_pytorch/train_configs.py | 41 ++++- dalle2_pytorch/trainer.py | 12 ++ train_decoder.py | 160 +++++++++++++++---- 4 files changed, 228 insertions(+), 59 deletions(-) diff --git a/dalle2_pytorch/dataloaders/decoder_loader.py b/dalle2_pytorch/dataloaders/decoder_loader.py index 5681e2a..572036b 100644 --- a/dalle2_pytorch/dataloaders/decoder_loader.py +++ b/dalle2_pytorch/dataloaders/decoder_loader.py @@ -21,7 +21,7 @@ def get_example_file(fs, path, file_format): """ return fs.glob(os.path.join(path, f"*.{file_format}"))[0] -def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception): +def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception): """Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields""" previous_tar_url = None current_embeddings = None @@ -56,7 +56,7 @@ def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handler # We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop if torch.count_nonzero(embedding) == 0: raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}") - sample["npy"] = embedding + sample[sample_key] = embedding yield sample except Exception as exn: # From wds implementation if handler(exn): @@ -84,18 +84,20 @@ def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.re continue else: break - skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper) -def verify_keys(samples, handler=wds.handlers.reraise_exception): +def join_embeddings(samples, handler=wds.handlers.reraise_exception): """ - Requires that both the image and embedding are present in the sample - This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter. + Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb } + either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist """ for sample in samples: try: - assert "jpg" in sample, f"Sample {sample['__key__']} missing image" - assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?" 
+ sample['emb'] = {} + if 'text_emb' in sample: + sample['emb']['text'] = sample['text_emb'] + if 'img_emb' in sample: + sample['emb']['img'] = sample['img_emb'] yield sample except Exception as exn: # From wds implementation if handler(exn): @@ -103,6 +105,23 @@ def verify_keys(samples, handler=wds.handlers.reraise_exception): else: break +def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception): + """ + Requires that both the image and embedding are present in the sample + This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter. + """ + for sample in samples: + try: + for key in required_keys: + assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}" + yield sample + except Exception as exn: # From wds implementation + if handler(exn): + continue + else: + break +key_verifier = wds.filters.pipelinefilter(verify_keys) + class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface): """ A fluid interface wrapper for DataPipline that returns image embedding pairs @@ -112,7 +131,8 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface): def __init__( self, urls, - embedding_folder_url=None, + img_embedding_folder_url=None, + text_embedding_folder_url=None, index_width=None, img_preproc=None, extra_keys=[], @@ -136,7 +156,12 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface): """ super().__init__() - keys = ["jpg", "npy"] + extra_keys + keys = ["jpg", "emb"] + extra_keys + # if img_embedding_folder_url is not None: + # keys.append("img_emb") + # if text_embedding_folder_url is not None: + # keys.append("text_emb") + # keys.extend(extra_keys) self.key_map = {key: i for i, key in enumerate(keys)} self.resampling = resample self.img_preproc = img_preproc @@ -145,7 +170,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface): # Then this has an s3 link for the webdataset and we need extra packages if shutil.which("s3cmd") is None: raise RuntimeError("s3cmd is required for s3 webdataset") - if "s3:" in embedding_folder_url: + if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url): # Then the embeddings are being loaded from s3 and fsspec requires s3fs try: import s3fs @@ -160,17 +185,24 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface): if shuffle_shards: self.append(wds.filters.shuffle(1000)) - if embedding_folder_url is not None: + if img_embedding_folder_url is not None: # There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues. 
- self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler)) + self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler)) + if text_embedding_folder_url is not None: + self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler)) self.append(wds.tarfile_to_samples(handler=handler)) self.append(wds.decode("pilrgb", handler=handler)) - if embedding_folder_url is not None: - # Then we are loading embeddings for a remote source + if img_embedding_folder_url is not None: + # Then we are loading image embeddings for a remote source assert index_width is not None, "Reading embeddings separately requires index width length to be given" - self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler)) - self.append(verify_keys) + self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler)) + if text_embedding_folder_url is not None: + # Then we are loading image embeddings for a remote source + assert index_width is not None, "Reading embeddings separately requires index width length to be given" + self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler)) + self.append(join_embeddings) + self.append(key_verifier(required_keys=keys, handler=handler)) # Apply preprocessing self.append(wds.map(self.preproc)) self.append(wds.to_tuple(*keys)) @@ -185,7 +217,8 @@ def create_image_embedding_dataloader( tar_url, num_workers, batch_size, - embeddings_url=None, + img_embeddings_url=None, + text_embeddings_url=None, index_width=None, shuffle_num = None, shuffle_shards = True, @@ -211,7 +244,8 @@ def create_image_embedding_dataloader( """ ds = ImageEmbeddingDataset( tar_url, - embeddings_url, + img_embedding_folder_url=img_embeddings_url, + text_embedding_folder_url=text_embeddings_url, index_width=index_width, shuffle_shards=shuffle_shards, resample=resample_shards, @@ -228,4 +262,4 @@ def create_image_embedding_dataloader( prefetch_factor=2, # This might be good to have high so the next npy file is prefetched pin_memory=True, shuffle=False - ) \ No newline at end of file + ) diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py index 80bc0cd..08f93dc 100644 --- a/dalle2_pytorch/train_configs.py +++ b/dalle2_pytorch/train_configs.py @@ -13,7 +13,7 @@ from dalle2_pytorch.dalle2_pytorch import ( Decoder, DiffusionPrior, DiffusionPriorNetwork, - XClipAdapter, + XClipAdapter ) # helper functions @@ -170,6 +170,8 @@ class DecoderConfig(BaseModel): unets: ListOrTuple(UnetConfig) image_size: int = None image_sizes: ListOrTuple(int) = None + condition_on_text_encodings: bool = False + clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided channels: int = 3 timesteps: int = 1000 loss_type: str = 'l2' @@ -180,9 +182,16 @@ class DecoderConfig(BaseModel): def create(self): decoder_kwargs = self.dict() + unet_configs = decoder_kwargs.pop('unets') unets = [Unet(**config) for config in unet_configs] - return Decoder(unets, **decoder_kwargs) + + has_clip = exists(decoder_kwargs.pop('clip')) + clip = None + if has_clip: + clip = self.clip.create() + + return Decoder(unets, clip=clip, **decoder_kwargs) @validator('image_sizes') def check_image_sizes(cls, image_sizes, values): @@ -194,8 +203,9 @@ class DecoderConfig(BaseModel): extra = "allow" class 
DecoderDataConfig(BaseModel): - webdataset_base_url: str # path to a webdataset with jpg images - embeddings_url: str # path to .npy files with embeddings + webdataset_base_url: str # path to a webdataset with jpg images + img_embeddings_url: Optional[str] # path to .npy files with embeddings + text_embeddings_url: Optional[str] # path to .npy files with embeddings num_workers: int = 4 batch_size: int = 64 start_shard: int = 0 @@ -268,3 +278,26 @@ class TrainDecoderConfig(BaseModel): with open(json_path) as f: config = json.load(f) return cls(**config) + + @root_validator + def check_has_embeddings(cls, values): + # Makes sure that enough information is provided to get the embeddings specified for training + data_config, decoder_config = values.get('data'), values.get('decoder') + if data_config is None or decoder_config is None: + # Then something else errored and we should just pass through + return values + using_text_embeddings = decoder_config.condition_on_text_encodings + using_clip = exists(decoder_config.clip) + img_emb_url = data_config.img_embeddings_url + text_emb_url = data_config.text_embeddings_url + if using_text_embeddings: + # Then we need some way to get the embeddings + assert using_clip or text_emb_url is not None, 'If condition_on_text_encodings is true, either clip or text_embeddings_url must be provided' + if using_clip: + if using_text_embeddings: + assert text_emb_url is None or img_emb_url is None, 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the embeddings' + else: + assert img_emb_url is None, 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings' + if text_emb_url: + assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason." 
+ return values diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index b02995e..420e5e9 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -578,6 +578,18 @@ class DecoderTrainer(nn.Module): return output + @torch.no_grad() + @cast_torch_tensor + @prior_sample_in_chunks + def embed_text(self, *args, **kwargs): + return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs) + + @torch.no_grad() + @cast_torch_tensor + @prior_sample_in_chunks + def embed_image(self, *args, **kwargs): + return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs) + @cast_torch_tensor def forward( self, diff --git a/train_decoder.py b/train_decoder.py index 832391b..3fe1289 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -6,6 +6,7 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker from dalle2_pytorch.train_configs import TrainDecoderConfig from dalle2_pytorch.utils import Timer, print_ribbon from dalle2_pytorch.dalle2_pytorch import resize_image_to +from clip import tokenize import torchvision import torch @@ -33,7 +34,8 @@ def exists(val): def create_dataloaders( available_shards, webdataset_base_url, - embeddings_url, + img_embeddings_url=None, + text_embeddings_url=None, shard_width=6, num_workers=4, batch_size=32, @@ -63,14 +65,15 @@ def create_dataloaders( test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split] val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split] - create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader( + create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader( tar_url=tar_urls, num_workers=num_workers, batch_size=batch_size if not for_sampling else n_sample_images, - embeddings_url=embeddings_url, + img_embeddings_url=img_embeddings_url, + text_embeddings_url=text_embeddings_url, index_width=index_width, shuffle_num = None, - extra_keys= ["txt"] if with_text else [], + extra_keys= ["txt"], shuffle_shards = shuffle, resample_shards = resample, img_preproc=img_preproc, @@ -79,8 +82,8 @@ def create_dataloaders( train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train) train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True) - val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True) - test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True) + val_dataloader = create_dataloader(val_urls, shuffle=False) + test_dataloader = create_dataloader(test_urls, shuffle=False) test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True) return { "train": train_dataloader, @@ -104,42 +107,65 @@ def get_example_data(dataloader, device, n=5): Samples the dataloader and returns a zipped list of examples """ images = [] - embeddings = [] + img_embeddings = [] + text_embeddings = [] captions = [] - dataset_keys = get_dataset_keys(dataloader) - has_caption = "txt" in dataset_keys - for data in dataloader: - if has_caption: - img, emb, txt = data + for img, emb, txt in dataloader: + img_emb, text_emb = emb.get('img'), emb.get('text') + if img_emb is not None: + img_emb = img_emb.to(device=device, dtype=torch.float) + img_embeddings.extend(list(img_emb)) else: - img, emb = data - txt = [""] * emb.shape[0] + # Then we add 
None img.shape[0] times + img_embeddings.extend([None]*img.shape[0]) + if text_emb is not None: + text_emb = text_emb.to(device=device, dtype=torch.float) + text_embeddings.extend(list(text_emb)) + else: + # Then we add None img.shape[0] times + text_embeddings.extend([None]*img.shape[0]) img = img.to(device=device, dtype=torch.float) - emb = emb.to(device=device, dtype=torch.float) images.extend(list(img)) - embeddings.extend(list(emb)) captions.extend(list(txt)) if len(images) >= n: break - return list(zip(images[:n], embeddings[:n], captions[:n])) + return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n])) -def generate_samples(trainer, example_data, text_prepend=""): +def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""): """ Takes example data and generates images from the embeddings Returns three lists: real images, generated images, and captions """ - real_images, embeddings, txts = zip(*example_data) - embeddings_tensor = torch.stack(embeddings) - samples = trainer.sample(embeddings_tensor) + real_images, img_embeddings, text_embeddings, txts = zip(*example_data) + sample_params = {} + if img_embeddings[0] is None: + # Generate image embeddings from clip + imgs_tensor = torch.stack(real_images) + img_embeddings, *_ = trainer.embed_image(imgs_tensor) + sample_params["image_embed"] = img_embeddings + else: + # Then we are using precomputed image embeddings + img_embeddings = torch.stack(img_embeddings) + sample_params["image_embed"] = img_embeddings + if condition_on_text_encodings: + if text_embeddings[0] is None: + # Generate text embeddings from text + tokenized_texts = tokenize(txts, truncate=True) + sample_params["text"] = tokenized_texts + else: + # Then we are using precomputed text embeddings + text_embeddings = torch.stack(text_embeddings) + sample_params["text_encodings"] = text_embeddings + samples = trainer.sample(**sample_params) generated_images = list(samples) captions = [text_prepend + txt for txt in txts] return real_images, generated_images, captions -def generate_grid_samples(trainer, examples, text_prepend=""): +def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""): """ Generates samples and uses torchvision to put them in a side by side grid for easy viewing """ - real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend) + real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend) real_image_size = real_images[0].shape[-1] generated_image_size = generated_images[0].shape[-1] @@ -151,7 +177,7 @@ def generate_grid_samples(trainer, examples, text_prepend=""): grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)] return grid_images, captions -def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None): +def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None): """ Computes evaluation metrics for the decoder """ @@ -161,7 +187,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID if len(examples) == 0: print("No data to evaluate. 
Check that your dataloader has shards.") return metrics - real_images, generated_images, captions = generate_samples(trainer, examples) + real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings) real_images = torch.stack(real_images).to(device=device, dtype=torch.float) generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float) # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8 @@ -250,6 +276,7 @@ def train( save_latest=True, save_best=True, unet_training_mask=None, + condition_on_text_encodings=False, **kwargs ): """ @@ -307,14 +334,22 @@ def train( last_snapshot = sample if next_task == 'train': - for i, (img, emb) in enumerate(dataloaders["train"]): + for i, (img, emb, txt) in enumerate(dataloaders["train"]): # We want to count the total number of samples across all processes sample_length_tensor[0] = len(img) all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this. total_samples = all_samples.sum().item() sample += total_samples samples_seen += total_samples - img, emb = send_to_device((img, emb)) + img_emb = emb.get('img') + has_img_embedding = img_emb is not None + if has_img_embedding: + img_emb, = send_to_device((img_emb,)) + text_emb = emb.get('text') + has_text_embedding = text_emb is not None + if has_text_embedding: + text_emb, = send_to_device((text_emb,)) + img, = send_to_device((img,)) trainer.train() for unet in range(1, trainer.num_unets+1): @@ -322,7 +357,20 @@ def train( if not unet_training_mask[unet-1]: # Unet index is the unet number - 1 continue - loss = trainer.forward(img, image_embed=emb, unet_number=unet) + forward_params = {} + if has_img_embedding: + forward_params['image_embed'] = img_emb + else: + # Forward pass automatically generates embedding + pass + if condition_on_text_encodings: + if has_text_embedding: + forward_params['text_encodings'] = text_emb + else: + # Then we need to pass the text instead + tokenized_texts = tokenize(txt, truncate=True) + forward_params['text'] = tokenized_texts + loss = trainer.forward(img, **forward_params, unet_number=unet) trainer.update(unet_number=unet) unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss @@ -366,7 +414,7 @@ def train( save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths) if exists(n_sample_images) and n_sample_images > 0: trainer.eval() - train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ") + train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ") tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step()) if epoch_samples is not None and sample >= epoch_samples: @@ -389,14 +437,35 @@ def train( all_samples = accelerator.gather(val_sample_length_tensor) total_samples = all_samples.sum().item() val_sample += total_samples - img, emb = send_to_device((img, emb)) + img_emb = emb.get('img') + has_img_embedding = img_emb is not None + if has_img_embedding: + img_emb, = send_to_device((img_emb,)) + text_emb = emb.get('text') + has_text_embedding = text_emb is not None + if has_text_embedding: + text_emb, = send_to_device((text_emb,)) + img, = send_to_device((img,)) for unet in range(1, len(decoder.unets)+1): if not unet_training_mask[unet-1]: # Unet index is the unet number - 1 # No need to evaluate an unchanging unet continue - - loss = 
trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet) + + forward_params = {} + if has_img_embedding: + forward_params['image_embed'] = img_emb.float() + else: + # Forward pass automatically generates embedding + pass + if condition_on_text_encodings: + if has_text_embedding: + forward_params['text_encodings'] = text_emb.float() + else: + # Then we need to pass the text instead + tokenized_texts = tokenize(txt, truncate=True) + forward_params['text'] = tokenized_texts + loss = trainer.forward(img.float(), **forward_params, unet_number=unet) average_val_loss_tensor[0, unet-1] += loss if i % VALID_CALC_LOSS_EVERY_ITERS == 0: @@ -423,7 +492,7 @@ def train( if next_task == 'eval': if exists(evaluate_config): accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40)) - evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict()) + evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings) if is_master: tracker.log(evaluation, step=step(), verbose=True) next_task = 'sample' @@ -434,8 +503,8 @@ def train( # Generate examples and save the model if we are the master # Generate sample images print(print_ribbon(f"Sampling Set {epoch}", repeat=40)) - test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ") - train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ") + test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ") + train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ") tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step()) tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step()) @@ -525,14 +594,35 @@ def initialize_training(config, config_path): # Create and initialize the tracker if we are the master tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy") + has_img_embeddings = config.data.img_embeddings_url is not None + has_text_embeddings = config.data.text_embeddings_url is not None + conditioning_on_text = config.decoder.condition_on_text_encodings + has_clip_model = config.decoder.clip is not None + data_source_string = "" + if has_img_embeddings: + data_source_string += "precomputed image embeddings" + elif has_clip_model: + data_source_string += "clip image embeddings generation" + else: + raise ValueError("No image embeddings source specified") + if conditioning_on_text: + if has_text_embeddings: + data_source_string += " and precomputed text embeddings" + elif has_clip_model: + data_source_string += " and clip text encoding generation" + else: + raise ValueError("No text embeddings source specified") + accelerator.print(print_ribbon("Loaded Config", repeat=40)) accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training") + accelerator.print(f"Training using {data_source_string}. 
{'conditioned on text' if conditioning_on_text else 'not conditioned on text'}") accelerator.print(f"Number of parameters: {num_parameters}") train(dataloaders, decoder, accelerator, tracker=tracker, inference_device=accelerator.device, load_config=config.load, evaluate_config=config.evaluate, + condition_on_text_encodings=config.decoder.condition_on_text_encodings, **config.train.dict(), )
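Usage note: a minimal sketch, with hypothetical paths, shard count, and index width, of calling create_image_embedding_dataloader (defined in dalle2_pytorch/dataloaders/decoder_loader.py) with the split img_embeddings_url / text_embeddings_url arguments introduced by this patch. Each batch then carries the joined embedding dict produced by join_embeddings, holding only the embeddings that were actually precomputed.

from torchvision import transforms
from dalle2_pytorch.dataloaders.decoder_loader import create_image_embedding_dataloader

# All paths, the shard count, and index_width below are hypothetical stand-ins.
dataloader = create_image_embedding_dataloader(
    tar_url=[f"/data/webdataset/{i:06d}.tar" for i in range(10)],
    num_workers=4,
    batch_size=32,
    img_embeddings_url="/data/img_embeddings/",   # precomputed image embeddings
    text_embeddings_url=None,                     # nothing precomputed for text; clip covers it
    index_width=4,                                # digits of the per-shard sample index in each key
    extra_keys=["txt"],                           # also yield the raw caption, as train_decoder.py does
    img_preproc=transforms.ToTensor(),            # convert decoded PIL images to tensors for collation
)

for img, emb, txt in dataloader:
    img_emb = emb.get("img")    # batched tensor, since img_embeddings_url was provided
    text_emb = emb.get("text")  # None here; train_decoder.py falls back to tokenizing txt for clip
    break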
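Configuration note: the new check_has_embeddings root validator only accepts combinations where each required embedding has exactly one source. Below is a sketch, with hypothetical values, of the decoder/data fields the validator reads; the empty clip dict is only a placeholder for an adapter config, and all other config sections are elided.

# Accepted: everything precomputed, no clip model loaded.
precomputed_only = {
    "decoder": {"condition_on_text_encodings": True, "clip": None},
    "data": {"img_embeddings_url": "/data/img_emb/", "text_embeddings_url": "/data/text_emb/"},
}

# Accepted: clip generates whatever is not precomputed (here, the text encodings).
# Providing neither embeddings url also works, since clip then generates both.
clip_fills_text = {
    "decoder": {"condition_on_text_encodings": True, "clip": {}},  # placeholder for an adapter config
    "data": {"img_embeddings_url": "/data/img_emb/", "text_embeddings_url": None},
}

# Rejected: clip loaded while both embedding urls are provided (flagged as redundant),
# text_embeddings_url set while condition_on_text_encodings is False (pointless loading),
# or neither img_embeddings_url nor clip given (no image embedding source at all).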