Compare commits

3 Commits

Author      SHA1        Message                                                    Date
Phil Wang   f545ce18f4  be able to turn off p2 loss reweighting for upsamplers    2022-06-20 09:43:31 -07:00
Phil Wang   fc7abf624d  in paper, blur sigma was 0.6                               2022-06-20 09:05:08 -07:00
Phil Wang   67f0740777  small cleanup                                              2022-06-20 08:59:51 -07:00
3 changed files with 14 additions and 8 deletions


@@ -1703,7 +1703,7 @@ class LowresConditioner(nn.Module):
     def __init__(
         self,
         downsample_first = True,
-        blur_sigma = (0.1, 0.2),
+        blur_sigma = 0.6,
         blur_kernel_size = 3,
     ):
         super().__init__()
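Note: the previous default, blur_sigma = (0.1, 0.2), presumably sampled the sigma from that range, whereas the commit pins it to the single value 0.6 cited from the paper. As a rough sketch of what this low-res blur augmentation amounts to, assuming torchvision's gaussian_blur as a stand-in for the conditioner's internal blurring (which this repo may implement differently):

    import torch
    from torchvision.transforms.functional import gaussian_blur

    # hypothetical batch of low-resolution conditioning images
    lowres_image = torch.randn(1, 3, 64, 64)

    # kernel_size = 3 and sigma = 0.6 mirror the defaults changed above
    blurred = gaussian_blur(lowres_image, kernel_size = 3, sigma = 0.6)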
@@ -1866,14 +1866,17 @@ class Decoder(nn.Module):
         if not exists(beta_schedule):
             beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))
         beta_schedule = cast_tuple(beta_schedule, num_unets)
+        p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)
         self.noise_schedulers = nn.ModuleList([])
-        for unet_beta_schedule in beta_schedule:
+        for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
             noise_scheduler = NoiseScheduler(
                 beta_schedule = unet_beta_schedule,
                 timesteps = timesteps,
                 loss_type = loss_type,
-                p2_loss_weight_gamma = p2_loss_weight_gamma,
+                p2_loss_weight_gamma = unet_p2_loss_weight_gamma,
                 p2_loss_weight_k = p2_loss_weight_k
             )
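Note: casting p2_loss_weight_gamma to a per-unet tuple is what lets p2 loss reweighting be turned off for upsamplers, per the first commit message: a gamma of 0 for an upsampler unet collapses a p2 weight of the form (k + snr) ** -gamma to 1 at every timestep, i.e. a plain unweighted loss. A minimal sketch of the new loop's behavior, with cast_tuple re-implemented here as an assumption rather than imported from the repo:

    # hypothetical stand-in for the repo's cast_tuple helper
    def cast_tuple(val, length = 1):
        return val if isinstance(val, tuple) else ((val,) * length)

    num_unets = 2

    beta_schedule = cast_tuple(('cosine', 'linear'), num_unets)

    # gamma = 1. for the base unet, 0. for the upsampler: the upsampler's
    # weight (k + snr) ** -0. is identically 1, disabling reweighting
    p2_loss_weight_gamma = cast_tuple((1., 0.), num_unets)

    for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
        print(unet_beta_schedule, unet_p2_loss_weight_gamma)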


@@ -627,10 +627,13 @@ class DecoderTrainer(nn.Module):
         self.max_grad_norm = max_grad_norm
         self.register_buffer('step', torch.tensor([0.]))
-        results = list(self.accelerator.prepare(decoder, *optimizers))
-        self.decoder = results.pop(0)
-        for opt_ind in range(len(optimizers)):
-            setattr(self, f'optim{opt_ind}', results.pop(0))
+        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+        self.decoder = decoder
+        for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
+            setattr(self, f'optim{opt_ind}', optimizer)
     def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
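Note: the "small cleanup" commit swaps index-juggling with results.pop(0) for star unpacking, relying on accelerate's prepare returning the prepared objects in the order they were passed. A toy illustration of the pattern (prepare below is a stand-in, not the accelerate API):

    # stand-in mimicking Accelerator.prepare's order-preserving return
    def prepare(*objects):
        return objects

    decoder, *optimizers = prepare('decoder', 'optim0', 'optim1')
    assert decoder == 'decoder'
    assert optimizers == ['optim0', 'optim1']

zip(range(len(optimizers)), optimizers) is equivalent to enumerate(optimizers), and the list(...) wrapper is not strictly needed since star unpacking accepts any iterable; both are kept as written in the commit.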


@@ -1 +1 @@
-__version__ = '0.11.0'
+__version__ = '0.11.2'