mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Merge pull request #234 from Veldrovive/deepspeed_fp16
Fixed issues with clip and deepspeed fp16
This commit is contained in:
@@ -241,7 +241,7 @@ class DecoderConfig(BaseModel):
|
||||
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
|
||||
channels: int = 3
|
||||
timesteps: int = 1000
|
||||
sample_timesteps: Optional[SingularOrIterable[int]] = None
|
||||
sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
|
||||
loss_type: str = 'l2'
|
||||
beta_schedule: ListOrTuple[str] = None # None means all cosine
|
||||
learned_variance: SingularOrIterable[bool] = True
|
||||
|
||||
@@ -519,7 +519,7 @@ class DecoderTrainer(nn.Module):
|
||||
clip = decoder.clip
|
||||
clip.to(precision_type)
|
||||
|
||||
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
decoder, train_dataloader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders['train'], *optimizers))
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
|
||||
@@ -134,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
|
||||
break
|
||||
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
|
||||
|
||||
def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
|
||||
def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
|
||||
"""
|
||||
Takes example data and generates images from the embeddings
|
||||
Returns three lists: real images, generated images, and captions
|
||||
@@ -144,7 +144,9 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
|
||||
if img_embeddings[0] is None:
|
||||
# Generate image embeddings from clip
|
||||
imgs_tensor = torch.stack(real_images)
|
||||
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
|
||||
assert clip is not None, "clip is None, but img_embeddings is None"
|
||||
imgs_tensor.to(device=device)
|
||||
img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
|
||||
sample_params["image_embed"] = img_embeddings
|
||||
else:
|
||||
# Then we are using precomputed image embeddings
|
||||
@@ -153,8 +155,10 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
|
||||
if condition_on_text_encodings:
|
||||
if text_embeddings[0] is None:
|
||||
# Generate text embeddings from text
|
||||
assert clip is not None, "clip is None, but text_embeddings is None"
|
||||
tokenized_texts = tokenize(txts, truncate=True)
|
||||
sample_params["text"] = tokenized_texts
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
sample_params["text_encodings"] = text_encodings
|
||||
else:
|
||||
# Then we are using precomputed text embeddings
|
||||
text_embeddings = torch.stack(text_embeddings)
|
||||
@@ -166,7 +170,7 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
|
||||
sample_params["image"] = torch.stack(real_images)
|
||||
if device is not None:
|
||||
sample_params["_device"] = device
|
||||
samples = trainer.sample(**sample_params)
|
||||
samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16
|
||||
generated_images = list(samples)
|
||||
captions = [text_prepend + txt for txt in txts]
|
||||
if match_image_size:
|
||||
@@ -174,15 +178,15 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
|
||||
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
|
||||
return real_images, generated_images, captions
|
||||
|
||||
def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
|
||||
def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
|
||||
"""
|
||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||
"""
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
|
||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||
return grid_images, captions
|
||||
|
||||
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||
"""
|
||||
Computes evaluation metrics for the decoder
|
||||
"""
|
||||
@@ -192,7 +196,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
|
||||
if len(examples) == 0:
|
||||
print("No data to evaluate. Check that your dataloader has shards.")
|
||||
return metrics
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
|
||||
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
||||
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
||||
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||
@@ -265,6 +269,7 @@ def train(
|
||||
accelerator: Accelerator,
|
||||
tracker: Tracker,
|
||||
inference_device,
|
||||
clip=None,
|
||||
evaluate_config=None,
|
||||
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
|
||||
validation_samples = None,
|
||||
@@ -371,15 +376,19 @@ def train(
|
||||
forward_params['image_embed'] = img_emb
|
||||
else:
|
||||
# Forward pass automatically generates embedding
|
||||
pass
|
||||
assert clip is not None
|
||||
img_embed, img_encoding = clip.embed_image(img)
|
||||
forward_params['image_embed'] = img_embed
|
||||
if condition_on_text_encodings:
|
||||
if has_text_embedding:
|
||||
forward_params['text_encodings'] = text_emb
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
assert clip is not None
|
||||
tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
|
||||
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
|
||||
forward_params['text'] = tokenized_texts
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
forward_params['text_encodings'] = text_encodings
|
||||
loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
|
||||
trainer.update(unet_number=unet)
|
||||
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
|
||||
@@ -419,7 +428,7 @@ def train(
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
|
||||
if exists(n_sample_images) and n_sample_images > 0:
|
||||
trainer.eval()
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
if epoch_samples is not None and sample >= epoch_samples:
|
||||
@@ -462,15 +471,19 @@ def train(
|
||||
forward_params['image_embed'] = img_emb.float()
|
||||
else:
|
||||
# Forward pass automatically generates embedding
|
||||
pass
|
||||
assert clip is not None
|
||||
img_embed, img_encoding = clip.embed_image(img)
|
||||
forward_params['image_embed'] = img_embed
|
||||
if condition_on_text_encodings:
|
||||
if has_text_embedding:
|
||||
forward_params['text_encodings'] = text_emb.float()
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
assert clip is not None
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
|
||||
forward_params['text'] = tokenized_texts
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
forward_params['text_encodings'] = text_encodings
|
||||
loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
|
||||
average_val_loss_tensor[0, unet-1] += loss
|
||||
|
||||
@@ -498,7 +511,7 @@ def train(
|
||||
if next_task == 'eval':
|
||||
if exists(evaluate_config):
|
||||
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
|
||||
if is_master:
|
||||
tracker.log(evaluation, step=step())
|
||||
next_task = 'sample'
|
||||
@@ -509,8 +522,8 @@ def train(
|
||||
# Generate examples and save the model if we are the master
|
||||
# Generate sample images
|
||||
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
|
||||
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
@@ -532,6 +545,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
|
||||
"NumProcesses": accelerator.num_processes,
|
||||
"MixedPrecision": accelerator.mixed_precision
|
||||
}
|
||||
accelerator.wait_for_everyone() # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors
|
||||
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
|
||||
tracker.save_config(config_path, config_name='decoder_config.json')
|
||||
tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
|
||||
@@ -555,10 +569,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
# If we are in deepspeed fp16 mode, we must ensure learned variance is off
|
||||
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
|
||||
raise ValueError("DeepSpeed fp16 mode does not support learned variance")
|
||||
|
||||
if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
|
||||
# This is an invalid configuration until we figure out how to handle this
|
||||
raise ValueError("DeepSpeed does not support multi-node distributed training")
|
||||
|
||||
# Set up data
|
||||
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
|
||||
@@ -579,6 +589,11 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
seed = config.seed,
|
||||
)
|
||||
|
||||
# If clip is in the model, we need to remove it for compatibility with deepspeed
|
||||
clip = None
|
||||
if config.decoder.clip is not None:
|
||||
clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues
|
||||
config.decoder.clip = None
|
||||
# Create the decoder model and print basic info
|
||||
decoder = config.decoder.create()
|
||||
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
|
||||
@@ -590,7 +605,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
has_text_embeddings = config.data.text_embeddings_url is not None
|
||||
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
|
||||
|
||||
has_clip_model = config.decoder.clip is not None
|
||||
has_clip_model = clip is not None
|
||||
data_source_string = ""
|
||||
|
||||
if has_img_embeddings:
|
||||
@@ -615,6 +630,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
|
||||
|
||||
train(dataloaders, decoder, accelerator,
|
||||
clip=clip,
|
||||
tracker=tracker,
|
||||
inference_device=accelerator.device,
|
||||
evaluate_config=config.evaluate,
|
||||
|
||||
Reference in New Issue
Block a user