Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)
Merge pull request #234 from Veldrovive/deepspeed_fp16
Fixed issues with clip and deepspeed fp16
@@ -241,7 +241,7 @@ class DecoderConfig(BaseModel):
     clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
-    sample_timesteps: Optional[SingularOrIterable[int]] = None
+    sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
     loss_type: str = 'l2'
     beta_schedule: ListOrTuple[str] = None # None means all cosine
     learned_variance: SingularOrIterable[bool] = True
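For context (not part of the diff): the widened annotation presumably lets a per-unet list mix explicit sampling step counts with None entries, None meaning the full schedule for that unet. A minimal sketch of the accepted values, using a simplified local stand-in for the repository's SingularOrIterable alias:

from typing import List, Optional, TypeVar, Union
from pydantic import BaseModel

T = TypeVar('T')
SingularOrIterable = Union[T, List[T]]  # simplified stand-in, not the repo's definition

class DecoderConfigSketch(BaseModel):
    timesteps: int = 1000
    sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None

# A single int, a list of ints, or a list mixing ints with None all validate:
print(DecoderConfigSketch(sample_timesteps=250).sample_timesteps)          # 250
print(DecoderConfigSketch(sample_timesteps=[None, 250]).sample_timesteps)  # [None, 250]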
@@ -519,7 +519,7 @@ class DecoderTrainer(nn.Module):
             clip = decoder.clip
             clip.to(precision_type)
 
-        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+        decoder, train_dataloader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders['train'], *optimizers))
 
         self.decoder = decoder
 
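For context (not part of the diff): with the DeepSpeed backend, Hugging Face Accelerate needs to see the training dataloader in prepare() alongside the model and optimizers, for example to infer the per-device batch size when the DeepSpeed config leaves train_micro_batch_size_per_gpu as "auto" and to wrap the loader for the engine. A generic sketch of that pattern, independent of the trainer code:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()  # DeepSpeed / fp16 settings come from the accelerate config

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loader = DataLoader(TensorDataset(torch.randn(64, 16)), batch_size=8)

# Preparing the dataloader together with the model lets the distributed backend
# see the batch size and shard/wrap the loader appropriately.
model, loader, optimizer = accelerator.prepare(model, loader, optimizer)

for (batch,) in loader:
    loss = model(batch).pow(2).mean()
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()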
@@ -134,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
             break
     return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
 
-def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
+def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
     """
     Takes example data and generates images from the embeddings
     Returns three lists: real images, generated images, and captions
@@ -144,7 +144,9 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
     if img_embeddings[0] is None:
         # Generate image embeddings from clip
         imgs_tensor = torch.stack(real_images)
-        img_embeddings, *_ = trainer.embed_image(imgs_tensor)
+        assert clip is not None, "clip is None, but img_embeddings is None"
+        imgs_tensor.to(device=device)
+        img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
         sample_params["image_embed"] = img_embeddings
     else:
         # Then we are using precomputed image embeddings
@@ -153,8 +155,10 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
     if condition_on_text_encodings:
         if text_embeddings[0] is None:
             # Generate text embeddings from text
+            assert clip is not None, "clip is None, but text_embeddings is None"
             tokenized_texts = tokenize(txts, truncate=True)
-            sample_params["text"] = tokenized_texts
+            text_embed, text_encodings = clip.embed_text(tokenized_texts)
+            sample_params["text_encodings"] = text_encodings
         else:
             # Then we are using precomputed text embeddings
             text_embeddings = torch.stack(text_embeddings)
@@ -166,7 +170,7 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
     sample_params["image"] = torch.stack(real_images)
     if device is not None:
         sample_params["_device"] = device
-    samples = trainer.sample(**sample_params)
+    samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16
     generated_images = list(samples)
     captions = [text_prepend + txt for txt in txts]
     if match_image_size:
@@ -174,15 +178,15 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
         real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
     return real_images, generated_images, captions
 
-def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
+def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
     """
     Generates samples and uses torchvision to put them in a side by side grid for easy viewing
     """
-    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
+    real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions
 
-def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
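A usage sketch (not part of the diff): the sampling helpers now take the CLIP adapter explicitly instead of reading it off the decoder. The adapter choice (OpenAIClipAdapter) is an assumption, and trainer, test_example_data, and accelerator are assumed to come from the surrounding script:

from dalle2_pytorch import OpenAIClipAdapter

clip = OpenAIClipAdapter('ViT-L/14')  # any adapter matching the embeddings used in training

grid_images, captions = generate_grid_samples(
    trainer,
    test_example_data,
    clip,                              # new argument: the externally held CLIP adapter
    start_unet=1,
    end_unet=None,
    condition_on_text_encodings=True,
    cond_scale=1.0,
    device=accelerator.device,
    text_prepend="Test: ",
)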
@@ -192,7 +196,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
     if len(examples) == 0:
         print("No data to evaluate. Check that your dataloader has shards.")
         return metrics
-    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
+    real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
     # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -265,6 +269,7 @@ def train(
     accelerator: Accelerator,
     tracker: Tracker,
     inference_device,
+    clip=None,
     evaluate_config=None,
     epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
     validation_samples = None,
@@ -371,15 +376,19 @@ def train(
                     forward_params['image_embed'] = img_emb
                 else:
                     # Forward pass automatically generates embedding
-                    pass
+                    assert clip is not None
+                    img_embed, img_encoding = clip.embed_image(img)
+                    forward_params['image_embed'] = img_embed
                 if condition_on_text_encodings:
                     if has_text_embedding:
                         forward_params['text_encodings'] = text_emb
                     else:
                         # Then we need to pass the text instead
-                        tokenized_texts = tokenize(txt, truncate=True)
+                        assert clip is not None
+                        tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
                         assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
-                        forward_params['text'] = tokenized_texts
+                        text_embed, text_encodings = clip.embed_text(tokenized_texts)
+                        forward_params['text_encodings'] = text_encodings
                 loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
                 trainer.update(unet_number=unet)
                 unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
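For illustration (a standalone sketch, not the trainer's code): the new branches all follow the same pattern — when embeddings were not precomputed, tokenize the captions with OpenAI CLIP's tokenizer, embed image and text with the external adapter, and put the resulting embeddings/encodings into forward_params rather than raw text. The adapter choice and the toy batch below are assumptions:

import torch
from clip import tokenize                     # OpenAI CLIP tokenizer, as used in train_decoder.py
from dalle2_pytorch import OpenAIClipAdapter

device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip = OpenAIClipAdapter('ViT-B/32').to(device)  # toy stand-in for the configured adapter

img = torch.randn(2, 3, 224, 224, device=device)  # toy image batch at CLIP's input size
txt = ["a photo of a cat", "a photo of a dog"]

forward_params = {}

# No precomputed image embeddings: embed with the external clip adapter
img_embed, img_encoding = clip.embed_image(img)
forward_params['image_embed'] = img_embed

# No precomputed text embeddings: tokenize, embed, and pass encodings instead of raw tokens
tokenized_texts = tokenize(txt, truncate=True).to(device)
assert tokenized_texts.shape[0] == img.shape[0]
text_embed, text_encodings = clip.embed_text(tokenized_texts)
forward_params['text_encodings'] = text_encodings

# forward_params would then be handed to trainer.forward(img, **forward_params, ...)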
@@ -419,7 +428,7 @@ def train(
             save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
             if exists(n_sample_images) and n_sample_images > 0:
                 trainer.eval()
-                train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+                train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
                 tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
 
         if epoch_samples is not None and sample >= epoch_samples:
@@ -462,15 +471,19 @@ def train(
                     forward_params['image_embed'] = img_emb.float()
                 else:
                     # Forward pass automatically generates embedding
-                    pass
+                    assert clip is not None
+                    img_embed, img_encoding = clip.embed_image(img)
+                    forward_params['image_embed'] = img_embed
                 if condition_on_text_encodings:
                     if has_text_embedding:
                         forward_params['text_encodings'] = text_emb.float()
                     else:
                         # Then we need to pass the text instead
+                        assert clip is not None
                         tokenized_texts = tokenize(txt, truncate=True)
                         assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
-                        forward_params['text'] = tokenized_texts
+                        text_embed, text_encodings = clip.embed_text(tokenized_texts)
+                        forward_params['text_encodings'] = text_encodings
                 loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
                 average_val_loss_tensor[0, unet-1] += loss
 
@@ -498,7 +511,7 @@ def train(
         if next_task == 'eval':
             if exists(evaluate_config):
                 accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
+                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
                 if is_master:
                     tracker.log(evaluation, step=step())
             next_task = 'sample'
@@ -509,8 +522,8 @@ def train(
             # Generate examples and save the model if we are the master
             # Generate sample images
             print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
-            test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
-            train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+            test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
+            train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
             tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
             tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
 
@@ -532,6 +545,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
         "NumProcesses": accelerator.num_processes,
         "MixedPrecision": accelerator.mixed_precision
     }
+    accelerator.wait_for_everyone() # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors
    tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
    tracker.save_config(config_path, config_name='decoder_config.json')
    tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
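For context: accelerator.wait_for_everyone() is a cross-process barrier, so every rank reaches the tracker-creation point before any process tries to auto-resume the run. A generic sketch of the barrier pattern (not repository code):

import time
from accelerate import Accelerator

accelerator = Accelerator()

# Simulate ranks arriving at this point at different times (e.g. uneven setup work)
time.sleep(accelerator.process_index)

accelerator.wait_for_everyone()  # no rank continues until all ranks have arrived
print(f"rank {accelerator.process_index} creating its tracker now, in lockstep with the others")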
@@ -556,10 +570,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
         raise ValueError("DeepSpeed fp16 mode does not support learned variance")
 
-    if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
-        # This is an invalid configuration until we figure out how to handle this
-        raise ValueError("DeepSpeed does not support multi-node distributed training")
-
     # Set up data
     all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
     world_size = accelerator.num_processes
@@ -579,6 +589,11 @@ def initialize_training(config: TrainDecoderConfig, config_path):
         seed = config.seed,
     )
 
+    # If clip is in the model, we need to remove it for compatibility with deepspeed
+    clip = None
+    if config.decoder.clip is not None:
+        clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues
+        config.decoder.clip = None
     # Create the decoder model and print basic info
     decoder = config.decoder.create()
     get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
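For context (a simplified sketch of the new flow, other arguments omitted): the adapter is instantiated from the decoder config and then stripped from it, so the Decoder that DeepSpeed wraps carries no CLIP weights, while the training code still receives the adapter as a separate object and passes it on, as the final hunk below shows:

# `config` is assumed to be a loaded TrainDecoderConfig, as in initialize_training()
clip = None
if config.decoder.clip is not None:
    clip = config.decoder.clip.create()  # keep the adapter for use during training
    config.decoder.clip = None           # but leave it out of the decoder handed to DeepSpeed

decoder = config.decoder.create()        # decoder is built without a clip submodule

train(dataloaders, decoder, accelerator,
      clip=clip,                         # adapter passed alongside, not inside, the model
      tracker=tracker,
      inference_device=accelerator.device,
      evaluate_config=config.evaluate)   # remaining train() arguments omitted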
@@ -590,7 +605,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     has_text_embeddings = config.data.text_embeddings_url is not None
     conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
 
-    has_clip_model = config.decoder.clip is not None
+    has_clip_model = clip is not None
     data_source_string = ""
 
     if has_img_embeddings:
@@ -615,6 +630,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
         accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
 
     train(dataloaders, decoder, accelerator,
+        clip=clip,
         tracker=tracker,
         inference_device=accelerator.device,
         evaluate_config=config.evaluate,