mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-22 11:04:21 +01:00
Improved upsampler training (#181)
Sampling is now possible without the first decoder unet Non-training unets are deleted in the decoder trainer since they are never used and it is harder merge the models is they have keys in this state dict Fixed a mistake where clip was not re-added after saving
This commit is contained in:
@@ -75,6 +75,8 @@ def cast_tuple(val, length = None, validate = True):
|
||||
return out
|
||||
|
||||
def module_device(module):
|
||||
if isinstance(module, nn.Identity):
|
||||
return 'cpu' # It doesn't matter
|
||||
return next(module.parameters()).device
|
||||
|
||||
def zero_init_(m):
|
||||
@@ -2326,7 +2328,7 @@ class Decoder(nn.Module):
|
||||
|
||||
@property
|
||||
def condition_on_text_encodings(self):
|
||||
return any([unet.cond_on_text_encodings for unet in self.unets])
|
||||
return any([unet.cond_on_text_encodings for unet in self.unets if isinstance(unet, Unet)])
|
||||
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= self.num_unets
|
||||
@@ -2646,11 +2648,13 @@ class Decoder(nn.Module):
|
||||
@eval_decorator
|
||||
def sample(
|
||||
self,
|
||||
image = None,
|
||||
image_embed = None,
|
||||
text = None,
|
||||
text_encodings = None,
|
||||
batch_size = 1,
|
||||
cond_scale = 1.,
|
||||
start_at_unet_number = 1,
|
||||
stop_at_unet_number = None,
|
||||
distributed = False,
|
||||
inpaint_image = None,
|
||||
@@ -2671,14 +2675,22 @@ class Decoder(nn.Module):
|
||||
assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'
|
||||
|
||||
img = None
|
||||
if start_at_unet_number > 1:
|
||||
# Then we are not generating the first image and one must have been passed in
|
||||
assert exists(image), 'image must be passed in if starting at unet number > 1'
|
||||
assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
|
||||
prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
|
||||
img = resize_image_to(image, prev_unet_output_size, nearest = True)
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
|
||||
num_unets = self.num_unets
|
||||
cond_scale = cast_tuple(cond_scale, num_unets)
|
||||
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
|
||||
if unet_number < start_at_unet_number:
|
||||
continue # It's the easiest way to do it
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
|
||||
|
||||
with context:
|
||||
# prepare low resolution conditioning for upsamplers
|
||||
|
||||
@@ -530,11 +530,14 @@ class Tracker:
|
||||
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
|
||||
prior: DiffusionPrior = trainer.unwrap_model(prior)
|
||||
# Remove CLIP if it is part of the model
|
||||
original_clip = prior.clip
|
||||
prior.clip = None
|
||||
model_state_dict = prior.state_dict()
|
||||
prior.clip = original_clip
|
||||
elif isinstance(trainer, DecoderTrainer):
|
||||
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
|
||||
# Remove CLIP if it is part of the model
|
||||
original_clip = decoder.clip
|
||||
decoder.clip = None
|
||||
if trainer.use_ema:
|
||||
trainable_unets = decoder.unets
|
||||
@@ -543,6 +546,7 @@ class Tracker:
|
||||
decoder.unets = trainable_unets # Swap back
|
||||
else:
|
||||
model_state_dict = decoder.state_dict()
|
||||
decoder.clip = original_clip
|
||||
else:
|
||||
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
|
||||
state_dict = {
|
||||
|
||||
@@ -306,9 +306,11 @@ class DecoderTrainConfig(BaseModel):
|
||||
max_grad_norm: SingularOrIterable(float) = 0.5
|
||||
save_every_n_samples: int = 100000
|
||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||
cond_scale: Union[float, List[float]] = 1.0
|
||||
device: str = 'cuda:0'
|
||||
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||
validation_samples: int = None # Same as above but for validation.
|
||||
save_immediately: bool = False
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.999
|
||||
amp: bool = False
|
||||
|
||||
@@ -498,23 +498,27 @@ class DecoderTrainer(nn.Module):
|
||||
warmup_schedulers = []
|
||||
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
|
||||
optimizer = get_optimizer(
|
||||
unet.parameters(),
|
||||
lr = unet_lr,
|
||||
wd = unet_wd,
|
||||
eps = unet_eps,
|
||||
group_wd_params = group_wd_params,
|
||||
**kwargs
|
||||
)
|
||||
if isinstance(unet, nn.Identity):
|
||||
optimizers.append(None)
|
||||
schedulers.append(None)
|
||||
warmup_schedulers.append(None)
|
||||
else:
|
||||
optimizer = get_optimizer(
|
||||
unet.parameters(),
|
||||
lr = unet_lr,
|
||||
wd = unet_wd,
|
||||
eps = unet_eps,
|
||||
group_wd_params = group_wd_params,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
optimizers.append(optimizer)
|
||||
optimizers.append(optimizer)
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
|
||||
warmup_schedulers.append(warmup_scheduler)
|
||||
|
||||
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
|
||||
warmup_schedulers.append(warmup_scheduler)
|
||||
|
||||
schedulers.append(scheduler)
|
||||
schedulers.append(scheduler)
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
||||
@@ -590,7 +594,8 @@ class DecoderTrainer(nn.Module):
|
||||
for ind in range(0, self.num_unets):
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
|
||||
state_dict = optimizer.state_dict() if optimizer is not None else None
|
||||
save_obj = {**save_obj, optimizer_key: state_dict}
|
||||
|
||||
if self.use_ema:
|
||||
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
|
||||
@@ -612,8 +617,8 @@ class DecoderTrainer(nn.Module):
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
warmup_scheduler = self.warmup_schedulers[ind]
|
||||
|
||||
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
|
||||
if optimizer is not None:
|
||||
optimizer.load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if exists(warmup_scheduler):
|
||||
warmup_scheduler.last_step = last_step
|
||||
@@ -714,23 +719,32 @@ class DecoderTrainer(nn.Module):
|
||||
*args,
|
||||
unet_number = None,
|
||||
max_batch_size = None,
|
||||
return_lowres_cond_image=False,
|
||||
**kwargs
|
||||
):
|
||||
unet_number = self.validate_and_return_unet_number(unet_number)
|
||||
|
||||
total_loss = 0.
|
||||
|
||||
|
||||
using_amp = self.accelerator.mixed_precision != 'no'
|
||||
|
||||
cond_images = []
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
with self.accelerator.autocast():
|
||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||
loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
|
||||
# loss_obj may be a tuple with loss and cond_image
|
||||
if return_lowres_cond_image:
|
||||
loss, cond_image = loss_obj
|
||||
else:
|
||||
loss = loss_obj
|
||||
cond_image = None
|
||||
loss = loss * chunk_size_frac
|
||||
if cond_image is not None:
|
||||
cond_images.append(cond_image)
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
if self.training:
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
return total_loss
|
||||
if return_lowres_cond_image:
|
||||
return total_loss, torch.stack(cond_images)
|
||||
else:
|
||||
return total_loss
|
||||
|
||||
Reference in New Issue
Block a user