mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Distributed Training of the Decoder (#121)
* Converted decoder trainer to use accelerate * Fixed issue where metric evaluation would hang on distributed mode * Implemented functional saving Loading still fails due to some issue with the optimizer * Fixed issue with loading decoders * Fixed issue with tracker config * Fixed issue with amp Updated logging to be more logical * Saving checkpoint now saves position in training as well Fixed an issue with running out of gpu space due to loading weights into the gpu twice * Fixed ema for distributed training * Fixed isue where get_pkg_version was reintroduced * Changed decoder trainer to upload config as a file Fixed issue where loading best would error
This commit is contained in:
395
train_decoder.py
395
train_decoder.py
@@ -1,7 +1,8 @@
|
||||
from dalle2_pytorch import Unet, Decoder
|
||||
from pathlib import Path
|
||||
|
||||
from dalle2_pytorch.trainer import DecoderTrainer
|
||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
|
||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
||||
@@ -12,6 +13,8 @@ from torchmetrics.image.fid import FrechetInceptionDistance
|
||||
from torchmetrics.image.inception import InceptionScore
|
||||
from torchmetrics.image.kid import KernelInceptionDistance
|
||||
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs
|
||||
from accelerate.utils import dataclasses as accelerate_dataclasses
|
||||
import webdataset as wds
|
||||
import click
|
||||
|
||||
@@ -42,6 +45,7 @@ def create_dataloaders(
|
||||
train_prop = 0.75,
|
||||
val_prop = 0.15,
|
||||
test_prop = 0.10,
|
||||
seed = 0,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
@@ -52,7 +56,7 @@ def create_dataloaders(
|
||||
num_test = round(test_prop*len(available_shards))
|
||||
num_val = len(available_shards) - num_train - num_test
|
||||
assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
|
||||
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))
|
||||
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(seed))
|
||||
|
||||
# The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
|
||||
train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
|
||||
@@ -117,7 +121,6 @@ def get_example_data(dataloader, device, n=5):
|
||||
captions.extend(list(txt))
|
||||
if len(images) >= n:
|
||||
break
|
||||
print("Generated {} examples".format(len(images)))
|
||||
return list(zip(images[:n], embeddings[:n], captions[:n]))
|
||||
|
||||
def generate_samples(trainer, example_data, text_prepend=""):
|
||||
@@ -155,27 +158,34 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
|
||||
metrics = {}
|
||||
# Prepare the data
|
||||
examples = get_example_data(dataloader, device, n_evaluation_samples)
|
||||
if len(examples) == 0:
|
||||
print("No data to evaluate. Check that your dataloader has shards.")
|
||||
return metrics
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples)
|
||||
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
||||
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
||||
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||
|
||||
def null_sync(t, *args, **kwargs):
|
||||
return [t]
|
||||
|
||||
if exists(FID):
|
||||
fid = FrechetInceptionDistance(**FID)
|
||||
fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync)
|
||||
fid.to(device=device)
|
||||
fid.update(int_real_images, real=True)
|
||||
fid.update(int_generated_images, real=False)
|
||||
metrics["FID"] = fid.compute().item()
|
||||
if exists(IS):
|
||||
inception = InceptionScore(**IS)
|
||||
inception = InceptionScore(**IS, dist_sync_fn=null_sync)
|
||||
inception.to(device=device)
|
||||
inception.update(int_real_images)
|
||||
is_mean, is_std = inception.compute()
|
||||
metrics["IS_mean"] = is_mean.item()
|
||||
metrics["IS_std"] = is_std.item()
|
||||
if exists(KID):
|
||||
kernel_inception = KernelInceptionDistance(**KID)
|
||||
kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync)
|
||||
kernel_inception.to(device=device)
|
||||
kernel_inception.update(int_real_images, real=True)
|
||||
kernel_inception.update(int_generated_images, real=False)
|
||||
@@ -186,39 +196,47 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
|
||||
# Convert from [0, 1] to [-1, 1]
|
||||
renorm_real_images = real_images.mul(2).sub(1)
|
||||
renorm_generated_images = generated_images.mul(2).sub(1)
|
||||
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
|
||||
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
|
||||
lpips.to(device=device)
|
||||
lpips.update(renorm_real_images, renorm_generated_images)
|
||||
metrics["LPIPS"] = lpips.compute().item()
|
||||
|
||||
if trainer.accelerator.num_processes > 1:
|
||||
# Then we should sync the metrics
|
||||
metrics_order = sorted(metrics.keys())
|
||||
metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float)
|
||||
for i, metric_name in enumerate(metrics_order):
|
||||
metrics_tensor[0, i] = metrics[metric_name]
|
||||
metrics_tensor = trainer.accelerator.gather(metrics_tensor)
|
||||
metrics_tensor = metrics_tensor.mean(dim=0)
|
||||
for i, metric_name in enumerate(metrics_order):
|
||||
metrics[metric_name] = metrics_tensor[i].item()
|
||||
return metrics
|
||||
|
||||
def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
|
||||
def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths):
|
||||
"""
|
||||
Logs the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
if isinstance(relative_paths, str):
|
||||
relative_paths = [relative_paths]
|
||||
trainer_state_dict = {}
|
||||
trainer_state_dict["trainer"] = trainer.state_dict()
|
||||
trainer_state_dict['epoch'] = epoch
|
||||
trainer_state_dict['step'] = step
|
||||
trainer_state_dict['validation_losses'] = validation_losses
|
||||
for relative_path in relative_paths:
|
||||
tracker.save_state_dict(trainer_state_dict, relative_path)
|
||||
local_path = str(tracker.data_path / relative_path)
|
||||
trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses)
|
||||
tracker.save_file(local_path)
|
||||
|
||||
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
|
||||
"""
|
||||
Loads the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
print(print_ribbon(f"Loading model from {recall_source}"))
|
||||
state_dict = tracker.recall_state_dict(recall_source, **load_config.dict())
|
||||
trainer.load_state_dict(state_dict["trainer"])
|
||||
print("Model loaded")
|
||||
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
|
||||
trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}"))
|
||||
local_filepath = tracker.recall_file(recall_source, **load_config)
|
||||
state_dict = trainer.load(local_filepath)
|
||||
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0)
|
||||
|
||||
def train(
|
||||
dataloaders,
|
||||
decoder,
|
||||
accelerator,
|
||||
tracker,
|
||||
inference_device,
|
||||
load_config=None,
|
||||
@@ -237,17 +255,30 @@ def train(
|
||||
"""
|
||||
Trains a decoder on a dataset.
|
||||
"""
|
||||
trainer = DecoderTrainer( # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here
|
||||
is_master = accelerator.process_index == 0
|
||||
|
||||
trainer = DecoderTrainer(
|
||||
accelerator,
|
||||
decoder,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Set up starting model and parameters based on a recalled state dict
|
||||
start_step = 0
|
||||
start_epoch = 0
|
||||
validation_losses = []
|
||||
next_task = 'train'
|
||||
sample = 0
|
||||
val_sample = 0
|
||||
step = lambda: int(trainer.step.item())
|
||||
|
||||
if exists(load_config) and exists(load_config.source):
|
||||
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
|
||||
start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
|
||||
if next_task == 'train':
|
||||
sample = recalled_sample
|
||||
if next_task == 'val':
|
||||
val_sample = recalled_sample
|
||||
accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
|
||||
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
|
||||
trainer.to(device=inference_device)
|
||||
|
||||
if not exists(unet_training_mask):
|
||||
@@ -255,139 +286,185 @@ def train(
|
||||
unet_training_mask = [True] * trainer.num_unets
|
||||
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
|
||||
|
||||
print(print_ribbon("Generating Example Data", repeat=40))
|
||||
print("This can take a while to load the shard lists...")
|
||||
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
|
||||
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
|
||||
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
|
||||
accelerator.print("This can take a while to load the shard lists...")
|
||||
if is_master:
|
||||
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
|
||||
accelerator.print("Generated training examples")
|
||||
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
|
||||
accelerator.print("Generated testing examples")
|
||||
|
||||
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
|
||||
step = start_step
|
||||
|
||||
sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
|
||||
unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device)
|
||||
for epoch in range(start_epoch, epochs):
|
||||
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
|
||||
accelerator.print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
|
||||
|
||||
timer = Timer()
|
||||
last_sample = sample
|
||||
last_snapshot = sample
|
||||
|
||||
sample = 0
|
||||
last_sample = 0
|
||||
last_snapshot = 0
|
||||
if next_task == 'train':
|
||||
for i, (img, emb) in enumerate(dataloaders["train"]):
|
||||
# We want to count the total number of samples across all processes
|
||||
sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
||||
total_samples = all_samples.sum().item()
|
||||
sample += total_samples
|
||||
img, emb = send_to_device((img, emb))
|
||||
|
||||
losses = []
|
||||
trainer.train()
|
||||
for unet in range(1, trainer.num_unets+1):
|
||||
# Check if this is a unet we are training
|
||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||
continue
|
||||
|
||||
for i, (img, emb) in enumerate(dataloaders["train"]):
|
||||
step += 1
|
||||
sample += img.shape[0]
|
||||
img, emb = send_to_device((img, emb))
|
||||
|
||||
trainer.train()
|
||||
for unet in range(1, trainer.num_unets+1):
|
||||
# Check if this is a unet we are training
|
||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||
continue
|
||||
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
|
||||
trainer.update(unet_number=unet)
|
||||
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
|
||||
|
||||
samples_per_sec = (sample - last_sample) / timer.elapsed()
|
||||
timer.reset()
|
||||
last_sample = sample
|
||||
|
||||
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
|
||||
trainer.update(unet_number=unet)
|
||||
losses.append(loss)
|
||||
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
|
||||
# We want to average losses across all processes
|
||||
unet_all_losses = accelerator.gather(unet_losses_tensor)
|
||||
mask = unet_all_losses != 0
|
||||
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
|
||||
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
|
||||
log_data = {
|
||||
"Epoch": epoch,
|
||||
"Sample": sample,
|
||||
"Step": i,
|
||||
"Samples per second": samples_per_sec,
|
||||
**loss_map
|
||||
}
|
||||
# print(f"I am rank {accelerator.state.process_index}. Example weight: {trainer.decoder.state_dict()['module.unets.0.init_conv.convs.0.weight'][0,0,0,0]}")
|
||||
if is_master:
|
||||
tracker.log(log_data, step=step(), verbose=True)
|
||||
|
||||
samples_per_sec = (sample - last_sample) / timer.elapsed()
|
||||
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
|
||||
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
|
||||
print("Saving snapshot")
|
||||
last_snapshot = sample
|
||||
# We need to know where the model should be saved
|
||||
save_paths = []
|
||||
if save_latest:
|
||||
save_paths.append("latest.pth")
|
||||
if save_all:
|
||||
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth")
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
|
||||
if exists(n_sample_images) and n_sample_images > 0:
|
||||
trainer.eval()
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
if epoch_samples is not None and sample >= epoch_samples:
|
||||
break
|
||||
next_task = 'val'
|
||||
sample = 0
|
||||
|
||||
timer.reset()
|
||||
last_sample = sample
|
||||
all_average_val_losses = None
|
||||
if next_task == 'val':
|
||||
trainer.eval()
|
||||
accelerator.print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
|
||||
last_val_sample = val_sample
|
||||
val_sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
|
||||
average_val_loss_tensor = torch.zeros(1, trainer.num_unets, dtype=torch.float, device=inference_device)
|
||||
timer = Timer()
|
||||
accelerator.wait_for_everyone()
|
||||
i = 0
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
||||
val_sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(val_sample_length_tensor)
|
||||
total_samples = all_samples.sum().item()
|
||||
val_sample += total_samples
|
||||
img, emb = send_to_device((img, emb))
|
||||
|
||||
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
|
||||
average_loss = sum(losses) / len(losses)
|
||||
log_data = {
|
||||
"Training loss": average_loss,
|
||||
"Epoch": epoch,
|
||||
"Sample": sample,
|
||||
"Step": i,
|
||||
"Samples per second": samples_per_sec
|
||||
}
|
||||
tracker.log(log_data, step=step, verbose=True)
|
||||
losses = []
|
||||
for unet in range(1, len(decoder.unets)+1):
|
||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||
# No need to evaluate an unchanging unet
|
||||
continue
|
||||
|
||||
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
|
||||
average_val_loss_tensor[0, unet-1] += loss
|
||||
|
||||
if last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
|
||||
last_snapshot = sample
|
||||
# We need to know where the model should be saved
|
||||
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
|
||||
samples_per_sec = (val_sample - last_val_sample) / timer.elapsed()
|
||||
timer.reset()
|
||||
last_val_sample = val_sample
|
||||
accelerator.print(f"Epoch {epoch}/{epochs} Val Step {i} - Sample {val_sample} - {samples_per_sec:.2f} samples/sec")
|
||||
accelerator.print(f"Loss: {(average_val_loss_tensor / (i+1))}")
|
||||
accelerator.print("")
|
||||
|
||||
if validation_samples is not None and val_sample >= validation_samples:
|
||||
break
|
||||
print(f"Rank {accelerator.state.process_index} finished validation after {i} steps")
|
||||
accelerator.wait_for_everyone()
|
||||
average_val_loss_tensor /= i+1
|
||||
# Gather all the average loss tensors
|
||||
all_average_val_losses = accelerator.gather(average_val_loss_tensor)
|
||||
if is_master:
|
||||
unet_average_val_loss = all_average_val_losses.mean(dim=0)
|
||||
val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
|
||||
tracker.log(val_loss_map, step=step(), verbose=True)
|
||||
next_task = 'eval'
|
||||
|
||||
if next_task == 'eval':
|
||||
if exists(evaluate_config):
|
||||
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
|
||||
if is_master:
|
||||
tracker.log(evaluation, step=step(), verbose=True)
|
||||
next_task = 'sample'
|
||||
val_sample = 0
|
||||
|
||||
if next_task == 'sample':
|
||||
if is_master:
|
||||
# Generate examples and save the model if we are the master
|
||||
# Generate sample images
|
||||
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
|
||||
# Get the same paths
|
||||
save_paths = []
|
||||
if save_latest:
|
||||
save_paths.append("latest.pth")
|
||||
if save_all:
|
||||
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
|
||||
if all_average_val_losses is not None:
|
||||
average_loss = all_average_val_losses.mean(dim=0).item()
|
||||
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
|
||||
save_paths.append("best.pth")
|
||||
validation_losses.append(average_loss)
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
|
||||
next_task = 'train'
|
||||
|
||||
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
|
||||
|
||||
if exists(n_sample_images) and n_sample_images > 0:
|
||||
trainer.eval()
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
|
||||
|
||||
if exists(epoch_samples) and sample >= epoch_samples:
|
||||
break
|
||||
|
||||
trainer.eval()
|
||||
print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
|
||||
with torch.no_grad():
|
||||
sample = 0
|
||||
average_loss = 0
|
||||
timer = Timer()
|
||||
for i, (img, emb, *_) in enumerate(dataloaders["val"]):
|
||||
sample += img.shape[0]
|
||||
img, emb = send_to_device((img, emb))
|
||||
|
||||
for unet in range(1, len(decoder.unets)+1):
|
||||
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
|
||||
average_loss += loss
|
||||
|
||||
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
|
||||
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
|
||||
print(f"Loss: {average_loss / (i+1)}")
|
||||
print("")
|
||||
|
||||
if exists(validation_samples) and sample >= validation_samples:
|
||||
break
|
||||
|
||||
average_loss /= i+1
|
||||
log_data = {
|
||||
"Validation loss": average_loss
|
||||
}
|
||||
tracker.log(log_data, step=step, verbose=True)
|
||||
|
||||
# Compute evaluation metrics
|
||||
if exists(evaluate_config):
|
||||
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
|
||||
tracker.log(evaluation, step=step, verbose=True)
|
||||
|
||||
# Generate sample images
|
||||
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
|
||||
|
||||
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
|
||||
# Get the same paths
|
||||
save_paths = []
|
||||
if save_latest:
|
||||
save_paths.append("latest.pth")
|
||||
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
|
||||
save_paths.append("best.pth")
|
||||
validation_losses.append(average_loss)
|
||||
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
|
||||
|
||||
def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
|
||||
def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None):
|
||||
"""
|
||||
Creates a tracker of the specified type and initializes special features based on the full config
|
||||
"""
|
||||
tracker_config = config.tracker
|
||||
init_config = {}
|
||||
accelerator_config = {
|
||||
"Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
|
||||
"DistributedType": accelerator.distributed_type,
|
||||
"NumProcesses": accelerator.num_processes,
|
||||
"MixedPrecision": accelerator.mixed_precision
|
||||
}
|
||||
init_config = { "config": {**config.dict(), **accelerator_config} }
|
||||
data_path = data_path or tracker_config.data_path
|
||||
tracker_type = tracker_type or tracker_config.tracker_type
|
||||
|
||||
if exists(tracker_config.init_config):
|
||||
init_config["config"] = tracker_config.init_config
|
||||
|
||||
if tracker_type == "console":
|
||||
tracker = ConsoleTracker(**init_config)
|
||||
if tracker_type == "dummy":
|
||||
tracker = DummyTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
elif tracker_type == "console":
|
||||
tracker = ConsoleTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
elif tracker_type == "wandb":
|
||||
# We need to initialize the resume state here
|
||||
load_config = config.load
|
||||
@@ -401,51 +478,63 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
|
||||
init_config["project"] = tracker_config.wandb_project
|
||||
tracker = WandbTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute()))
|
||||
else:
|
||||
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
|
||||
return tracker
|
||||
|
||||
def initialize_training(config):
|
||||
# Create the save path
|
||||
if "cuda" in config.train.device:
|
||||
assert torch.cuda.is_available(), "CUDA is not available"
|
||||
device = torch.device(config.train.device)
|
||||
torch.cuda.set_device(device)
|
||||
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
|
||||
def initialize_training(config, config_path):
|
||||
# Make sure if we are not loading, distributed models are initialized to the same values
|
||||
torch.manual_seed(config.seed)
|
||||
|
||||
# Set up accelerator for configurable distributed training
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
||||
|
||||
# Set up data
|
||||
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
|
||||
world_size = accelerator.num_processes
|
||||
rank = accelerator.process_index
|
||||
shards_per_process = len(all_shards) // world_size
|
||||
assert shards_per_process > 0, "Not enough shards to split evenly"
|
||||
my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
|
||||
dataloaders = create_dataloaders (
|
||||
available_shards=all_shards,
|
||||
available_shards=my_shards,
|
||||
img_preproc = config.data.img_preproc,
|
||||
train_prop = config.data.splits.train,
|
||||
val_prop = config.data.splits.val,
|
||||
test_prop = config.data.splits.test,
|
||||
n_sample_images=config.train.n_sample_images,
|
||||
**config.data.dict()
|
||||
**config.data.dict(),
|
||||
rank = rank,
|
||||
seed = config.seed,
|
||||
)
|
||||
|
||||
decoder = config.decoder.create().to(device = device)
|
||||
# Create the decoder model and print basic info
|
||||
decoder = config.decoder.create()
|
||||
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||
print(print_ribbon("Loaded Config", repeat=40))
|
||||
print(f"Number of parameters: {num_parameters}")
|
||||
|
||||
tracker = create_tracker(config, **config.tracker.dict())
|
||||
# Create and initialize the tracker if we are the master
|
||||
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
|
||||
|
||||
train(dataloaders, decoder,
|
||||
accelerator.print(print_ribbon("Loaded Config", repeat=40))
|
||||
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
|
||||
accelerator.print(f"Number of parameters: {num_parameters}")
|
||||
train(dataloaders, decoder, accelerator,
|
||||
tracker=tracker,
|
||||
inference_device=device,
|
||||
inference_device=accelerator.device,
|
||||
load_config=config.load,
|
||||
evaluate_config=config.evaluate,
|
||||
**config.train.dict(),
|
||||
)
|
||||
|
||||
|
||||
# Create a simple click command line interface to load the config and start the training
|
||||
@click.command()
|
||||
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
|
||||
def main(config_file):
|
||||
print("Recalling config from {}".format(config_file))
|
||||
config = TrainDecoderConfig.from_json_path(config_file)
|
||||
initialize_training(config)
|
||||
|
||||
config_file_path = Path(config_file)
|
||||
config = TrainDecoderConfig.from_json_path(str(config_file_path))
|
||||
initialize_training(config, config_path=config_file_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user