Improved upsampler training (#181)

Sampling is now possible without the first decoder unet
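For context, a minimal sketch of how sampling might be started from an upsampler using the new arguments added below (the two-unet decoder, tensor shapes, and variable names are assumptions for illustration, not part of the commit):

```python
import torch

# `decoder` is assumed to be a trained two-unet Decoder and `image_embed` a batch of
# CLIP image embeddings from the matching prior; both are placeholders in this sketch.
low_res = torch.randn(4, 3, 64, 64)   # stand-in for unet 1 output (or real low-res images)

samples = decoder.sample(
    image = low_res,                   # required when start_at_unet_number > 1
    image_embed = image_embed,
    batch_size = 4,                    # must match image.shape[0]
    start_at_unet_number = 2,          # skip the base unet entirely
    stop_at_unet_number = 2,           # stop after the first upsampler
    cond_scale = [1.0, 3.0],           # a single float or one value per unet
)
```

The passed image is resized to the previous unet's output resolution before being fed to the upsampler.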

Non-training unets are deleted in the decoder trainer since they are never used, and it is harder to merge the models if they have keys in the state dict
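A rough sketch of what that replacement looks like (paraphrasing the `move_unets` helper added in the training script below; `decoder` and `unet_training_mask` are assumed to already exist):

```python
from torch import nn

# Unets excluded by the training mask are swapped for nn.Identity placeholders,
# so they contribute no parameters to the optimizers or the saved state dict.
for i, trainable in enumerate(unet_training_mask):
    if not trainable:
        decoder.unets[i] = nn.Identity()
```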

Fixed a mistake where clip was not re-added after saving
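The save path now stashes and restores CLIP around serialization, roughly (a simplified sketch of the tracker change below):

```python
# Detach CLIP so it is not written into the checkpoint, then put it back so the
# in-memory model keeps conditioning on text/images after the save.
original_clip = decoder.clip
decoder.clip = None
model_state_dict = decoder.state_dict()
decoder.clip = original_clip
```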
Author: Aidan Dempster
Date: 2022-07-19 22:07:50 -04:00
Committed by: GitHub
Parent: 4b912a38c6
Commit: 4145474bab
6 changed files with 104 additions and 49 deletions


@@ -69,6 +69,7 @@ Settings for controlling the training hyperparameters.
| `wd` | No | `0.01` | The weight decay. |
| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
+| `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. |
| `device` | No | `cuda:0` | The device to train on. |
| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
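As an illustrative assumption (not part of this change), the new `cond_scale` field might sit alongside the other fields documented above in the training section of a decoder config, for example:

```python
# Hypothetical training-section fragment of a decoder config; values are examples only.
train_config = {
    "wd": 0.01,
    "max_grad_norm": 0.5,
    "save_every_n_samples": 100000,
    "cond_scale": [1.0, 3.5],   # one conditioning scale per unet; a single float also works
    "device": "cuda:0",
}
```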


@@ -75,6 +75,8 @@ def cast_tuple(val, length = None, validate = True):
    return out

def module_device(module):
+    if isinstance(module, nn.Identity):
+        return 'cpu' # It doesn't matter
    return next(module.parameters()).device

def zero_init_(m):
@@ -2326,7 +2328,7 @@ class Decoder(nn.Module):
    @property
    def condition_on_text_encodings(self):
-        return any([unet.cond_on_text_encodings for unet in self.unets])
+        return any([unet.cond_on_text_encodings for unet in self.unets if isinstance(unet, Unet)])

    def get_unet(self, unet_number):
        assert 0 < unet_number <= self.num_unets
@@ -2646,11 +2648,13 @@ class Decoder(nn.Module):
    @eval_decorator
    def sample(
        self,
+        image = None,
        image_embed = None,
        text = None,
        text_encodings = None,
        batch_size = 1,
        cond_scale = 1.,
+        start_at_unet_number = 1,
        stop_at_unet_number = None,
        distributed = False,
        inpaint_image = None,
@@ -2671,14 +2675,22 @@ class Decoder(nn.Module):
        assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'

        img = None
+        if start_at_unet_number > 1:
+            # Then we are not generating the first image and one must have been passed in
+            assert exists(image), 'image must be passed in if starting at unet number > 1'
+            assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
+            prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
+            img = resize_image_to(image, prev_unet_output_size, nearest = True)

        is_cuda = next(self.parameters()).is_cuda

        num_unets = self.num_unets
        cond_scale = cast_tuple(cond_scale, num_unets)

        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
+            if unet_number < start_at_unet_number:
+                continue # It's the easiest way to do it

-            context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
+            context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()

            with context:
                # prepare low resolution conditioning for upsamplers


@@ -530,11 +530,14 @@ class Tracker:
            prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
            prior: DiffusionPrior = trainer.unwrap_model(prior)
            # Remove CLIP if it is part of the model
+            original_clip = prior.clip
            prior.clip = None
            model_state_dict = prior.state_dict()
+            prior.clip = original_clip
        elif isinstance(trainer, DecoderTrainer):
            decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
            # Remove CLIP if it is part of the model
+            original_clip = decoder.clip
            decoder.clip = None
            if trainer.use_ema:
                trainable_unets = decoder.unets
@@ -543,6 +546,7 @@ class Tracker:
                decoder.unets = trainable_unets # Swap back
            else:
                model_state_dict = decoder.state_dict()
+            decoder.clip = original_clip
        else:
            raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
        state_dict = {


@@ -306,9 +306,11 @@ class DecoderTrainConfig(BaseModel):
    max_grad_norm: SingularOrIterable(float) = 0.5
    save_every_n_samples: int = 100000
    n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
+    cond_scale: Union[float, List[float]] = 1.0
    device: str = 'cuda:0'
    epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
    validation_samples: int = None # Same as above but for validation.
+    save_immediately: bool = False
    use_ema: bool = True
    ema_beta: float = 0.999
    amp: bool = False


@@ -498,6 +498,11 @@ class DecoderTrainer(nn.Module):
        warmup_schedulers = []

        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
+            if isinstance(unet, nn.Identity):
+                optimizers.append(None)
+                schedulers.append(None)
+                warmup_schedulers.append(None)
+            else:
            optimizer = get_optimizer(
                unet.parameters(),
                lr = unet_lr,
@@ -508,7 +513,6 @@ class DecoderTrainer(nn.Module):
            )
            optimizers.append(optimizer)
            scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
            warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
@@ -590,7 +594,8 @@ class DecoderTrainer(nn.Module):
        for ind in range(0, self.num_unets):
            optimizer_key = f'optim{ind}'
            optimizer = getattr(self, optimizer_key)
-            save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
+            state_dict = optimizer.state_dict() if optimizer is not None else None
+            save_obj = {**save_obj, optimizer_key: state_dict}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -612,8 +617,8 @@ class DecoderTrainer(nn.Module):
            optimizer_key = f'optim{ind}'
            optimizer = getattr(self, optimizer_key)
            warmup_scheduler = self.warmup_schedulers[ind]
-            self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
+            if optimizer is not None:
+                optimizer.load_state_dict(loaded_obj[optimizer_key])

            if exists(warmup_scheduler):
                warmup_scheduler.last_step = last_step
@@ -714,23 +719,32 @@ class DecoderTrainer(nn.Module):
        *args,
        unet_number = None,
        max_batch_size = None,
+        return_lowres_cond_image=False,
        **kwargs
    ):
        unet_number = self.validate_and_return_unet_number(unet_number)
        total_loss = 0.
+        cond_images = []
+        using_amp = self.accelerator.mixed_precision != 'no'

        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            with self.accelerator.autocast():
-                loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
+                loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
+                # loss_obj may be a tuple with loss and cond_image
+                if return_lowres_cond_image:
+                    loss, cond_image = loss_obj
+                else:
+                    loss = loss_obj
+                    cond_image = None
                loss = loss * chunk_size_frac
+                if cond_image is not None:
+                    cond_images.append(cond_image)

            total_loss += loss.item()

            if self.training:
                self.accelerator.backward(loss)

+        if return_lowres_cond_image:
+            return total_loss, torch.stack(cond_images)
+        else:
            return total_loss


@@ -1,5 +1,6 @@
from pathlib import Path
from typing import List
+from datetime import timedelta
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
@@ -11,11 +12,12 @@ from clip import tokenize
import torchvision
import torch
+from torch import nn
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
-from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
from accelerate.utils import dataclasses as accelerate_dataclasses
import webdataset as wds
import click
@@ -132,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
            break
    return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))

-def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
+def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
    """
    Takes example data and generates images from the embeddings
    Returns three lists: real images, generated images, and captions
@@ -157,6 +159,13 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
        # Then we are using precomputed text embeddings
        text_embeddings = torch.stack(text_embeddings)
        sample_params["text_encodings"] = text_embeddings
+    sample_params["start_at_unet_number"] = start_unet
+    sample_params["stop_at_unet_number"] = end_unet
+    if start_unet > 1:
+        # If we are only training upsamplers
+        sample_params["image"] = torch.stack(real_images)
+    if device is not None:
+        sample_params["_device"] = device
    samples = trainer.sample(**sample_params)
    generated_images = list(samples)
    captions = [text_prepend + txt for txt in txts]
@@ -165,15 +174,15 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
        real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
    return real_images, generated_images, captions

-def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
+def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
    """
    Generates samples and uses torchvision to put them in a side by side grid for easy viewing
    """
-    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
+    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
    grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
    return grid_images, captions

-def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
    """
    Computes evaluation metrics for the decoder
    """
@@ -183,7 +192,7 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
    if len(examples) == 0:
        print("No data to evaluate. Check that your dataloader has shards.")
        return metrics
-    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
+    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
    real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
    generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
    # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -259,11 +268,13 @@ def train(
    evaluate_config=None,
    epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
    validation_samples = None,
+    save_immediately=False,
    epochs = 20,
    n_sample_images = 5,
    save_every_n_samples = 100000,
    unet_training_mask=None,
    condition_on_text_encodings=False,
+    cond_scale=1.0,
    **kwargs
):
    """
@@ -271,6 +282,21 @@ def train(
""" """
is_master = accelerator.process_index == 0 is_master = accelerator.process_index == 0
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * len(decoder.unets)
assert len(unet_training_mask) == len(decoder.unets), f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
trainable_unet_numbers = [i+1 for i, trainable in enumerate(unet_training_mask) if trainable]
first_trainable_unet = trainable_unet_numbers[0]
last_trainable_unet = trainable_unet_numbers[-1]
def move_unets(unet_training_mask):
for i in range(len(decoder.unets)):
if not unet_training_mask[i]:
# Replace the unet from the module list with a nn.Identity(). This training script never uses unets that aren't being trained so this is fine.
decoder.unets[i] = nn.Identity().to(inference_device)
# Remove non-trainable unets
move_unets(unet_training_mask)
trainer = DecoderTrainer( trainer = DecoderTrainer(
decoder=decoder, decoder=decoder,
accelerator=accelerator, accelerator=accelerator,
@@ -285,6 +311,7 @@ def train(
    sample = 0
    samples_seen = 0
    val_sample = 0
+    step = lambda: int(trainer.num_steps_taken(unet_number=first_trainable_unet))

    if tracker.can_recall:
        start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
@@ -296,13 +323,6 @@ def train(
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}") accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device) trainer.to(device=inference_device)
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
accelerator.print(print_ribbon("Generating Example Data", repeat=40)) accelerator.print(print_ribbon("Generating Example Data", repeat=40))
accelerator.print("This can take a while to load the shard lists...") accelerator.print("This can take a while to load the shard lists...")
if is_master: if is_master:
@@ -360,7 +380,7 @@ def train(
                tokenized_texts = tokenize(txt, truncate=True)
                assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
                forward_params['text'] = tokenized_texts
-            loss = trainer.forward(img, **forward_params, unet_number=unet)
+            loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
            trainer.update(unet_number=unet)
            unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -373,10 +393,10 @@ def train(
            unet_all_losses = accelerator.gather(unet_losses_tensor)
            mask = unet_all_losses != 0
            unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
-            loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
+            loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if unet_training_mask[index] }

            # gather decay rate on each UNet
-            ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
+            ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets) if unet_training_mask[index]}

            log_data = {
                "Epoch": epoch,
@@ -391,7 +411,7 @@ def train(
            if is_master:
                tracker.log(log_data, step=step())

-            if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
+            if is_master and (last_snapshot + save_every_n_samples < sample or (save_immediately and i == 0)): # This will miss by some amount every time, but it's not a big deal... I hope
                # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
                print("Saving snapshot")
                last_snapshot = sample
@@ -399,7 +419,7 @@ def train(
                save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
                if exists(n_sample_images) and n_sample_images > 0:
                    trainer.eval()
-                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
+                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
                    tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

            if epoch_samples is not None and sample >= epoch_samples:
@@ -449,8 +469,9 @@ def train(
                else:
                    # Then we need to pass the text instead
                    tokenized_texts = tokenize(txt, truncate=True)
+                    assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
                    forward_params['text'] = tokenized_texts
-                loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
+                loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
                average_val_loss_tensor[0, unet-1] += loss

                if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
@@ -477,7 +498,7 @@ def train(
            if next_task == 'eval':
                if exists(evaluate_config):
                    accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-                    evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
+                    evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
                    if is_master:
                        tracker.log(evaluation, step=step())
                next_task = 'sample'
@@ -488,15 +509,15 @@ def train(
                # Generate examples and save the model if we are the master
                # Generate sample images
                print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
-                test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
-                train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
+                test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
+                train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
                tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
                tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

                print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
                is_best = False
                if all_average_val_losses is not None:
-                    average_loss = all_average_val_losses.mean(dim=0).item()
+                    average_loss = all_average_val_losses.mean(dim=0).sum() / sum(unet_training_mask)
                    if len(validation_losses) == 0 or average_loss < min(validation_losses):
                        is_best = True
                    validation_losses.append(average_loss)
@@ -522,7 +543,8 @@ def initialize_training(config: TrainDecoderConfig, config_path):
    # Set up accelerator for configurable distributed training
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
-    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
+    init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
+    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
    if accelerator.num_processes > 1:
        # We are using distributed training and want to immediately ensure all can connect