allow for overriding use of EMA during sampling in decoder trainer with use_non_ema keyword, also fix some issues with automatic normalization of images and low res conditioning image if latent diffusion is in play

This commit is contained in:
Phil Wang
2022-05-16 11:18:30 -07:00
parent 1212f7058d
commit f4016f6302
4 changed files with 80 additions and 10 deletions

View File

@@ -179,8 +179,8 @@ class EMA(nn.Module):
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
self.update_every = update_every
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
@@ -189,6 +189,9 @@ class EMA(nn.Module):
device = self.initted.device
self.ema_model.to(device)
def copy_params_from_model_to_ema(self):
self.ema_model.state_dict(self.online_model.state_dict())
def update(self):
self.step += 1
@@ -196,7 +199,7 @@ class EMA(nn.Module):
return
if not self.initted:
self.ema_model.state_dict(self.online_model.state_dict())
self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
@@ -405,6 +408,9 @@ class DecoderTrainer(nn.Module):
@torch.no_grad()
@cast_torch_tensor
def sample(self, *args, **kwargs):
if kwargs.pop('use_non_ema', False):
return self.decoder.sample(*args, **kwargs)
if self.use_ema:
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling