mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
be able to turn off p2 loss reweighting for upsamplers
This commit is contained in:
@@ -1866,14 +1866,17 @@ class Decoder(nn.Module):
|
|||||||
if not exists(beta_schedule):
|
if not exists(beta_schedule):
|
||||||
beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))
|
beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))
|
||||||
|
|
||||||
|
beta_schedule = cast_tuple(beta_schedule, num_unets)
|
||||||
|
p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)
|
||||||
|
|
||||||
self.noise_schedulers = nn.ModuleList([])
|
self.noise_schedulers = nn.ModuleList([])
|
||||||
|
|
||||||
for unet_beta_schedule in beta_schedule:
|
for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
|
||||||
noise_scheduler = NoiseScheduler(
|
noise_scheduler = NoiseScheduler(
|
||||||
beta_schedule = unet_beta_schedule,
|
beta_schedule = unet_beta_schedule,
|
||||||
timesteps = timesteps,
|
timesteps = timesteps,
|
||||||
loss_type = loss_type,
|
loss_type = loss_type,
|
||||||
p2_loss_weight_gamma = p2_loss_weight_gamma,
|
p2_loss_weight_gamma = unet_p2_loss_weight_gamma,
|
||||||
p2_loss_weight_k = p2_loss_weight_k
|
p2_loss_weight_k = p2_loss_weight_k
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.11.1'
|
__version__ = '0.11.2'
|
||||||
|
|||||||
Reference in New Issue
Block a user