Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
Improved upsampler training (#181)
Sampling is now possible without the first decoder unet. Non-training unets are deleted in the decoder trainer since they are never used, and it is harder to merge the models if they have keys in the state dict. Also fixed a mistake where CLIP was not re-added to the model after saving.
README.md

@@ -69,6 +69,7 @@ Settings for controlling the training hyperparameters.
 | `wd` | No | `0.01` | The weight decay. |
 | `max_grad_norm`| No | `0.5` | The grad norm clipping. |
 | `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
+| `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. |
 | `device` | No | `cuda:0` | The device to train on. |
 | `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
 | `validation_samples` | No | `None` | The number of samples to use for validation. None means the entire validation set. |
dalle2_pytorch/dalle2_pytorch.py

@@ -75,6 +75,8 @@ def cast_tuple(val, length = None, validate = True):
     return out
 
 def module_device(module):
+    if isinstance(module, nn.Identity):
+        return 'cpu' # It doesn't matter
     return next(module.parameters()).device
 
 def zero_init_(m):
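The `nn.Identity` guard matters because the placeholder unets introduced by this commit have no parameters, so the old implementation would raise `StopIteration`. A minimal standalone sketch of the behavior (not the repo's code):

import torch
from torch import nn

def module_device(module):
    if isinstance(module, nn.Identity):
        return 'cpu'  # placeholder modules have no parameters to inspect
    return next(module.parameters()).device

print(module_device(nn.Identity()))    # 'cpu'; next(parameters()) would raise StopIteration here
print(module_device(nn.Linear(2, 2)))  # device of the layer's weights, e.g. cpu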
@@ -2326,7 +2328,7 @@ class Decoder(nn.Module):
 
     @property
     def condition_on_text_encodings(self):
-        return any([unet.cond_on_text_encodings for unet in self.unets])
+        return any([unet.cond_on_text_encodings for unet in self.unets if isinstance(unet, Unet)])
 
     def get_unet(self, unet_number):
         assert 0 < unet_number <= self.num_unets
@@ -2646,11 +2648,13 @@ class Decoder(nn.Module):
     @eval_decorator
     def sample(
         self,
+        image = None,
         image_embed = None,
         text = None,
         text_encodings = None,
         batch_size = 1,
         cond_scale = 1.,
+        start_at_unet_number = 1,
         stop_at_unet_number = None,
         distributed = False,
         inpaint_image = None,
@@ -2671,14 +2675,22 @@ class Decoder(nn.Module):
         assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'
 
         img = None
+        if start_at_unet_number > 1:
+            # Then we are not generating the first image and one must have been passed in
+            assert exists(image), 'image must be passed in if starting at unet number > 1'
+            assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
+            prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
+            img = resize_image_to(image, prev_unet_output_size, nearest = True)
         is_cuda = next(self.parameters()).is_cuda
 
         num_unets = self.num_unets
         cond_scale = cast_tuple(cond_scale, num_unets)
 
         for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
+            if unet_number < start_at_unet_number:
+                continue # It's the easiest way to do it
 
-            context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
+            context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
 
             with context:
                 # prepare low resolution conditioning for upsamplers
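A hedged usage sketch of the new sampling path; the decoder, its two unets, and the embedding/image sizes are assumptions for illustration, not taken from the commit. With start_at_unet_number = 2 the base unet is skipped, so an image batch must be supplied and is resized to the previous unet's output size:

import torch

# `decoder` is assumed to be a trained two-unet Decoder with image_sizes = (64, 256).
image_embed = torch.randn(4, 512)         # CLIP image embeddings (illustrative size)
base_images = torch.randn(4, 3, 64, 64)   # outputs of the base unet, or real 64px images

upsampled = decoder.sample(
    image = base_images,            # required when start_at_unet_number > 1
    image_embed = image_embed,
    batch_size = 4,                 # must match image.shape[0]
    start_at_unet_number = 2,       # skip the base unet and only run the upsampler
)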
dalle2_pytorch/trackers.py

@@ -530,11 +530,14 @@ class Tracker:
         prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
         prior: DiffusionPrior = trainer.unwrap_model(prior)
         # Remove CLIP if it is part of the model
+        original_clip = prior.clip
         prior.clip = None
         model_state_dict = prior.state_dict()
+        prior.clip = original_clip
     elif isinstance(trainer, DecoderTrainer):
         decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
         # Remove CLIP if it is part of the model
+        original_clip = decoder.clip
         decoder.clip = None
         if trainer.use_ema:
             trainable_unets = decoder.unets
@@ -543,6 +546,7 @@ class Tracker:
             decoder.unets = trainable_unets # Swap back
         else:
             model_state_dict = decoder.state_dict()
+        decoder.clip = original_clip
     else:
         raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
     state_dict = {
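The underlying fix is a detach-then-reattach pattern around state_dict(). A minimal sketch with generic modules (not the real DiffusionPrior/Decoder classes):

import torch
from torch import nn

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.core = nn.Linear(4, 4)
        self.clip = nn.Linear(4, 4)  # stands in for the frozen CLIP submodule

model = Wrapper()

# Detach clip so its weights are not serialized, then reattach it so the
# in-memory model keeps working after the checkpoint is written.
original_clip = model.clip
model.clip = None
model_state_dict = model.state_dict()  # contains only core.* keys
model.clip = original_clip

assert not any(key.startswith('clip') for key in model_state_dict)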
dalle2_pytorch/train_configs.py

@@ -306,9 +306,11 @@ class DecoderTrainConfig(BaseModel):
     max_grad_norm: SingularOrIterable(float) = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
+    cond_scale: Union[float, List[float]] = 1.0
     device: str = 'cuda:0'
     epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
     validation_samples: int = None # Same as above but for validation.
+    save_immediately: bool = False
     use_ema: bool = True
     ema_beta: float = 0.999
     amp: bool = False
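For reference, a hedged fragment of the train section of a decoder training config using the two new fields; the values are illustrative and the remaining fields keep their documented defaults:

train_config = {
    "wd": 0.01,
    "max_grad_norm": 0.5,
    "save_every_n_samples": 100000,
    "cond_scale": [1.0, 3.5],    # new: one conditioning scale per unet
    "device": "cuda:0",
    "save_immediately": False,   # new: take a snapshot on the very first step if true
    "use_ema": True,
    "ema_beta": 0.999,
    "amp": False,
}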
dalle2_pytorch/trainer.py

@@ -498,6 +498,11 @@ class DecoderTrainer(nn.Module):
         warmup_schedulers = []
 
         for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
+            if isinstance(unet, nn.Identity):
+                optimizers.append(None)
+                schedulers.append(None)
+                warmup_schedulers.append(None)
+            else:
                 optimizer = get_optimizer(
                     unet.parameters(),
                     lr = unet_lr,
@@ -508,7 +513,6 @@ class DecoderTrainer(nn.Module):
                 )
-
                 optimizers.append(optimizer)
 
                 scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
 
                 warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
@@ -590,7 +594,8 @@ class DecoderTrainer(nn.Module):
         for ind in range(0, self.num_unets):
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
-            save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
+            state_dict = optimizer.state_dict() if optimizer is not None else None
+            save_obj = {**save_obj, optimizer_key: state_dict}
 
         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -612,8 +617,8 @@ class DecoderTrainer(nn.Module):
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
             warmup_scheduler = self.warmup_schedulers[ind]
-            self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
+            if optimizer is not None:
+                optimizer.load_state_dict(loaded_obj[optimizer_key])
 
             if exists(warmup_scheduler):
                 warmup_scheduler.last_step = last_step
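Both changes follow from the same invariant: a frozen unet is an nn.Identity with a None optimizer slot, so (de)serialization must skip the gaps. A standalone sketch of the pattern:

import torch
from torch import nn

unets = nn.ModuleList([nn.Identity(), nn.Linear(8, 8)])  # unet 1 frozen out, unet 2 trainable
optimizers = [None if isinstance(unet, nn.Identity) else torch.optim.Adam(unet.parameters())
              for unet in unets]

# Saving: None placeholders serialize as None instead of crashing on .state_dict()
saved = [opt.state_dict() if opt is not None else None for opt in optimizers]

# Loading: only restore optimizers that actually exist
for opt, state in zip(optimizers, saved):
    if opt is not None:
        opt.load_state_dict(state)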
@@ -714,23 +719,32 @@ class DecoderTrainer(nn.Module):
         *args,
         unet_number = None,
         max_batch_size = None,
+        return_lowres_cond_image=False,
         **kwargs
     ):
         unet_number = self.validate_and_return_unet_number(unet_number)
 
         total_loss = 0.
+        cond_images = []
 
-        using_amp = self.accelerator.mixed_precision != 'no'
-
         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with self.accelerator.autocast():
-                loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
+                loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
+                # loss_obj may be a tuple with loss and cond_image
+                if return_lowres_cond_image:
+                    loss, cond_image = loss_obj
+                else:
+                    loss = loss_obj
+                    cond_image = None
                 loss = loss * chunk_size_frac
+                if cond_image is not None:
+                    cond_images.append(cond_image)
 
             total_loss += loss.item()
 
             if self.training:
                 self.accelerator.backward(loss)
 
+        if return_lowres_cond_image:
+            return total_loss, torch.stack(cond_images)
+        else:
             return total_loss
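A hedged call sketch; the trainer and an image batch with its embeddings are assumed. With the flag set, forward returns the accumulated loss together with the stacked low-resolution conditioning images fed to the upsampler:

# `trainer` is a DecoderTrainer; `images` and `image_embed` are assumed tensors.
loss, lowres_cond_images = trainer.forward(
    images,
    image_embed = image_embed,
    unet_number = 2,                    # an upsampler unet
    return_lowres_cond_image = True,    # also return the conditioning images
)

loss = trainer.forward(images, image_embed = image_embed, unet_number = 2)  # default: just the loss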
train_decoder.py

@@ -1,5 +1,6 @@
 from pathlib import Path
 from typing import List
+from datetime import timedelta
 
 from dalle2_pytorch.trainer import DecoderTrainer
 from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
@@ -11,11 +12,12 @@ from clip import tokenize
 
 import torchvision
 import torch
+from torch import nn
 from torchmetrics.image.fid import FrechetInceptionDistance
 from torchmetrics.image.inception import InceptionScore
 from torchmetrics.image.kid import KernelInceptionDistance
 from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
-from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
 from accelerate.utils import dataclasses as accelerate_dataclasses
 import webdataset as wds
 import click
@@ -132,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
             break
     return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
 
-def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
+def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
     """
     Takes example data and generates images from the embeddings
     Returns three lists: real images, generated images, and captions
@@ -157,6 +159,13 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
         # Then we are using precomputed text embeddings
         text_embeddings = torch.stack(text_embeddings)
         sample_params["text_encodings"] = text_embeddings
+    sample_params["start_at_unet_number"] = start_unet
+    sample_params["stop_at_unet_number"] = end_unet
+    if start_unet > 1:
+        # If we are only training upsamplers
+        sample_params["image"] = torch.stack(real_images)
+    if device is not None:
+        sample_params["_device"] = device
     samples = trainer.sample(**sample_params)
     generated_images = list(samples)
     captions = [text_prepend + txt for txt in txts]
@@ -165,15 +174,15 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
         real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
     return real_images, generated_images, captions
 
-def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
+def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
     """
     Generates samples and uses torchvision to put them in a side by side grid for easy viewing
     """
-    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
+    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions
 
-def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
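A hedged sketch of how the training script can now scope sampling to the unets being trained; the trainer, example data, and device are assumed from the surrounding script, and arguments besides the new unet range behave as before:

# Unet numbers are 1-indexed.
grid_images, captions = generate_grid_samples(
    trainer,
    test_example_data,
    start_unet = 2,      # first unet being trained, e.g. an upsampler
    end_unet = 2,        # last unet being trained
    condition_on_text_encodings = True,
    cond_scale = 3.5,
    device = inference_device,
    text_prepend = "Test: ",
)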
@@ -183,7 +192,7 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
     if len(examples) == 0:
         print("No data to evaluate. Check that your dataloader has shards.")
         return metrics
-    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
+    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
     # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -259,11 +268,13 @@ def train(
     evaluate_config=None,
     epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
     validation_samples = None,
+    save_immediately=False,
     epochs = 20,
     n_sample_images = 5,
     save_every_n_samples = 100000,
     unet_training_mask=None,
     condition_on_text_encodings=False,
+    cond_scale=1.0,
     **kwargs
 ):
     """
@@ -271,6 +282,21 @@ def train(
     """
     is_master = accelerator.process_index == 0
 
+    if not exists(unet_training_mask):
+        # Then the unet mask should be true for all unets in the decoder
+        unet_training_mask = [True] * len(decoder.unets)
+    assert len(unet_training_mask) == len(decoder.unets), f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
+    trainable_unet_numbers = [i+1 for i, trainable in enumerate(unet_training_mask) if trainable]
+    first_trainable_unet = trainable_unet_numbers[0]
+    last_trainable_unet = trainable_unet_numbers[-1]
+    def move_unets(unet_training_mask):
+        for i in range(len(decoder.unets)):
+            if not unet_training_mask[i]:
+                # Replace the unet from the module list with a nn.Identity(). This training script never uses unets that aren't being trained so this is fine.
+                decoder.unets[i] = nn.Identity().to(inference_device)
+    # Remove non-trainable unets
+    move_unets(unet_training_mask)
+
     trainer = DecoderTrainer(
         decoder=decoder,
         accelerator=accelerator,
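The mask handling reduces to simple index arithmetic over the 1-indexed unet numbers; a standalone sketch:

unet_training_mask = [False, True, True]   # e.g. train only the two upsamplers

trainable_unet_numbers = [i + 1 for i, trainable in enumerate(unet_training_mask) if trainable]
first_trainable_unet = trainable_unet_numbers[0]   # 2
last_trainable_unet = trainable_unet_numbers[-1]   # 3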
@@ -285,6 +311,7 @@ def train(
     sample = 0
     samples_seen = 0
     val_sample = 0
+    step = lambda: int(trainer.num_steps_taken(unet_number=first_trainable_unet))
 
     if tracker.can_recall:
         start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
@@ -296,13 +323,6 @@ def train(
     accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
     trainer.to(device=inference_device)
 
-    if not exists(unet_training_mask):
-        # Then the unet mask should be true for all unets in the decoder
-        unet_training_mask = [True] * trainer.num_unets
-    first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
-    step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
-    assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
-
     accelerator.print(print_ribbon("Generating Example Data", repeat=40))
     accelerator.print("This can take a while to load the shard lists...")
     if is_master:
@@ -360,7 +380,7 @@ def train(
                 tokenized_texts = tokenize(txt, truncate=True)
                 assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
                 forward_params['text'] = tokenized_texts
-            loss = trainer.forward(img, **forward_params, unet_number=unet)
+            loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
             trainer.update(unet_number=unet)
             unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
 
@@ -373,10 +393,10 @@ def train(
             unet_all_losses = accelerator.gather(unet_losses_tensor)
             mask = unet_all_losses != 0
             unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
-            loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
+            loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if unet_training_mask[index] }
 
             # gather decay rate on each UNet
-            ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
+            ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets) if unet_training_mask[index]}
 
             log_data = {
                 "Epoch": epoch,
@@ -391,7 +411,7 @@ def train(
             if is_master:
                 tracker.log(log_data, step=step())
 
-            if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
+            if is_master and (last_snapshot + save_every_n_samples < sample or (save_immediately and i == 0)): # This will miss by some amount every time, but it's not a big deal... I hope
                 # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
                 print("Saving snapshot")
                 last_snapshot = sample
@@ -399,7 +419,7 @@ def train(
                 save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
                 if exists(n_sample_images) and n_sample_images > 0:
                     trainer.eval()
-                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
+                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
                     tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
 
             if epoch_samples is not None and sample >= epoch_samples:
@@ -449,8 +469,9 @@ def train(
             else:
                 # Then we need to pass the text instead
                 tokenized_texts = tokenize(txt, truncate=True)
+                assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
                 forward_params['text'] = tokenized_texts
-            loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
+            loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
             average_val_loss_tensor[0, unet-1] += loss
 
             if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
@@ -477,7 +498,7 @@ def train(
         if next_task == 'eval':
             if exists(evaluate_config):
                 accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
+                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
                 if is_master:
                     tracker.log(evaluation, step=step())
                 next_task = 'sample'
@@ -488,15 +509,15 @@ def train(
             # Generate examples and save the model if we are the master
             # Generate sample images
             print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
-            test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
-            train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
+            test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
+            train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
             tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
             tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
 
             print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
             is_best = False
             if all_average_val_losses is not None:
-                average_loss = all_average_val_losses.mean(dim=0).item()
+                average_loss = all_average_val_losses.mean(dim=0).sum() / sum(unet_training_mask)
                 if len(validation_losses) == 0 or average_loss < min(validation_losses):
                     is_best = True
                 validation_losses.append(average_loss)
@@ -522,7 +543,8 @@ def initialize_training(config: TrainDecoderConfig, config_path):
 
     # Set up accelerator for configurable distributed training
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
-    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
+    init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
+    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
 
     if accelerator.num_processes > 1:
         # We are using distributed training and want to immediately ensure all can connect
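A self-contained sketch of the new accelerator setup; the one-hour process-group timeout is presumably there so slow ranks (e.g. while shard lists load) are not killed before the first collective call:

from datetime import timedelta
from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)  # the script takes this value from its config
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60 * 60))
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])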