mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
Improved upsampler training (#181)
Sampling is now possible without the first decoder unet.
Non-training unets are deleted in the decoder trainer, since they are never used and it is harder to merge the models if they have keys in the state dict.
Fixed a mistake where clip was not re-added after saving.
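For context, a minimal sketch (not the library's actual code; the names below are illustrative) of why replacing a non-training unet with nn.Identity keeps it out of the checkpoint: nn.Identity holds no parameters, so it contributes no optimizer state and no state-dict keys.

import torch.nn as nn

# stand-ins for real unets; assume we only train the upsampler at index 1
unets = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])
train_only = 1

# replace every unet we are not training with an Identity module
for i in range(len(unets)):
    if i != train_only:
        unets[i] = nn.Identity()

print(list(unets[0].parameters()))      # [] -- Identity has no parameters
print(list(unets.state_dict().keys()))  # only '1.weight', '1.bias' remain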
@@ -498,23 +498,27 @@ class DecoderTrainer(nn.Module):
         warmup_schedulers = []
 
         for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
-            optimizer = get_optimizer(
-                unet.parameters(),
-                lr = unet_lr,
-                wd = unet_wd,
-                eps = unet_eps,
-                group_wd_params = group_wd_params,
-                **kwargs
-            )
-
-            optimizers.append(optimizer)
-
-            scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
-
-            warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
-            warmup_schedulers.append(warmup_scheduler)
-
-            schedulers.append(scheduler)
+            if isinstance(unet, nn.Identity):
+                optimizers.append(None)
+                schedulers.append(None)
+                warmup_schedulers.append(None)
+            else:
+                optimizer = get_optimizer(
+                    unet.parameters(),
+                    lr = unet_lr,
+                    wd = unet_wd,
+                    eps = unet_eps,
+                    group_wd_params = group_wd_params,
+                    **kwargs
+                )
+
+                optimizers.append(optimizer)
+
+                scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+
+                warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
+                warmup_schedulers.append(warmup_scheduler)
+
+                schedulers.append(scheduler)
 
             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))
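The isinstance guard above is what makes the Identity substitution safe: a PyTorch optimizer refuses an empty parameter list, which is exactly what nn.Identity yields. A minimal sketch, using plain Adam in place of the repo's get_optimizer helper:

import torch
import torch.nn as nn

unet = nn.Identity()

try:
    torch.optim.Adam(unet.parameters(), lr = 1e-4)
except ValueError as err:
    print(err)  # "optimizer got an empty parameter list"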
@@ -590,7 +594,8 @@ class DecoderTrainer(nn.Module):
         for ind in range(0, self.num_unets):
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
-            save_obj = {**save_obj, optimizer_key: self.accelerator.unwrap_model(optimizer).state_dict()}
+            state_dict = optimizer.state_dict() if optimizer is not None else None
+            save_obj = {**save_obj, optimizer_key: state_dict}
 
         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
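Storing None for the optimizers of deleted unets is what makes checkpoints from different training runs easy to combine, per the commit message. A hypothetical merge helper (not part of the library) under the save format above, where each 'optim{i}' key is either a state dict or None:

def merge_checkpoints(ckpt_a, ckpt_b, num_unets):
    merged = {}
    for ind in range(num_unets):
        key = f'optim{ind}'
        # prefer whichever run actually trained this unet
        merged[key] = ckpt_a[key] if ckpt_a[key] is not None else ckpt_b[key]
    return merged

ckpt_a = {'optim0': {'state': 'a'}, 'optim1': None}  # run that trained unet 0
ckpt_b = {'optim0': None, 'optim1': {'state': 'b'}}  # run that trained unet 1
print(merge_checkpoints(ckpt_a, ckpt_b, num_unets = 2))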
@@ -612,8 +617,8 @@ class DecoderTrainer(nn.Module):
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
             warmup_scheduler = self.warmup_schedulers[ind]
-
-            self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
+            if optimizer is not None:
+                optimizer.load_state_dict(loaded_obj[optimizer_key])
 
             if exists(warmup_scheduler):
                 warmup_scheduler.last_step = last_step
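The load path mirrors the save guard: None entries round-trip through torch.save/torch.load and are simply skipped on restore. A self-contained sketch of that round trip (file path and setup are illustrative):

import torch
import torch.nn as nn

opt = torch.optim.Adam(nn.Linear(4, 4).parameters(), lr = 1e-4)
save_obj = {'optim0': opt.state_dict(), 'optim1': None}  # unet 1 not trained
torch.save(save_obj, '/tmp/ckpt.pt')

loaded = torch.load('/tmp/ckpt.pt')
for ind, optimizer in enumerate([opt, None]):
    state = loaded[f'optim{ind}']
    if optimizer is not None:  # same guard as in the diff
        optimizer.load_state_dict(state)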
@@ -714,23 +719,32 @@ class DecoderTrainer(nn.Module):
         *args,
         unet_number = None,
         max_batch_size = None,
+        return_lowres_cond_image=False,
         **kwargs
     ):
         unet_number = self.validate_and_return_unet_number(unet_number)
 
         total_loss = 0.
-
-        using_amp = self.accelerator.mixed_precision != 'no'
-
+        cond_images = []
         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with self.accelerator.autocast():
-                loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
+                loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
+                # loss_obj may be a tuple with loss and cond_image
+                if return_lowres_cond_image:
+                    loss, cond_image = loss_obj
+                else:
+                    loss = loss_obj
+                    cond_image = None
                 loss = loss * chunk_size_frac
+                if cond_image is not None:
+                    cond_images.append(cond_image)
 
             total_loss += loss.item()
 
             if self.training:
                 self.accelerator.backward(loss)
 
-        return total_loss
+        if return_lowres_cond_image:
+            return total_loss, torch.stack(cond_images)
+        else:
+            return total_loss
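The forward above follows a chunked gradient-accumulation pattern: the batch is split into chunks of at most max_batch_size, each chunk's loss is scaled by its fraction of the batch so the chunk losses sum to the full-batch loss, and the optional conditioning images are collected per chunk. A standalone sketch of the same pattern, with an illustrative stand-in for the repo's split_args_and_kwargs helper (names and shapes are assumptions, not the library API):

import torch
import torch.nn as nn

def split_batch(batch, split_size):
    # yield (fraction_of_batch, chunk) pairs, like split_args_and_kwargs
    n = batch.shape[0]
    for chunk in batch.split(split_size):
        yield chunk.shape[0] / n, chunk

def chunked_loss(model, batch, max_batch_size, return_extra = False):
    total_loss = 0.
    extras = []
    for frac, chunk in split_batch(batch, max_batch_size):
        loss = model(chunk).pow(2).mean() * frac  # scale so chunk losses sum to the full-batch mean
        total_loss += loss.item()
        loss.backward()  # gradients accumulate across chunks
        if return_extra:
            extras.append(chunk.detach())
    return (total_loss, torch.cat(extras)) if return_extra else total_loss

model = nn.Linear(8, 8)
print(chunked_loss(model, torch.randn(32, 8), max_batch_size = 8))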