be able to turn off p2 loss reweighting for upsamplers

This commit is contained in:
Phil Wang
2022-06-20 09:43:31 -07:00
parent fc7abf624d
commit f545ce18f4
2 changed files with 6 additions and 3 deletions

View File

@@ -1866,14 +1866,17 @@ class Decoder(nn.Module):
if not exists(beta_schedule):
beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))
beta_schedule = cast_tuple(beta_schedule, num_unets)
p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)
self.noise_schedulers = nn.ModuleList([])
for unet_beta_schedule in beta_schedule:
for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
noise_scheduler = NoiseScheduler(
beta_schedule = unet_beta_schedule,
timesteps = timesteps,
loss_type = loss_type,
p2_loss_weight_gamma = p2_loss_weight_gamma,
p2_loss_weight_gamma = unet_p2_loss_weight_gamma,
p2_loss_weight_k = p2_loss_weight_k
)