mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 22:24:22 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f545ce18f4 | ||
|
|
fc7abf624d | ||
|
|
67f0740777 |
@@ -1703,7 +1703,7 @@ class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
downsample_first = True,
|
||||
blur_sigma = (0.1, 0.2),
|
||||
blur_sigma = 0.6,
|
||||
blur_kernel_size = 3,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1866,14 +1866,17 @@ class Decoder(nn.Module):
|
||||
if not exists(beta_schedule):
|
||||
beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))
|
||||
|
||||
beta_schedule = cast_tuple(beta_schedule, num_unets)
|
||||
p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)
|
||||
|
||||
self.noise_schedulers = nn.ModuleList([])
|
||||
|
||||
for unet_beta_schedule in beta_schedule:
|
||||
for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
|
||||
noise_scheduler = NoiseScheduler(
|
||||
beta_schedule = unet_beta_schedule,
|
||||
timesteps = timesteps,
|
||||
loss_type = loss_type,
|
||||
p2_loss_weight_gamma = p2_loss_weight_gamma,
|
||||
p2_loss_weight_gamma = unet_p2_loss_weight_gamma,
|
||||
p2_loss_weight_k = p2_loss_weight_k
|
||||
)
|
||||
|
||||
|
||||
@@ -627,10 +627,13 @@ class DecoderTrainer(nn.Module):
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
results = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
self.decoder = results.pop(0)
|
||||
for opt_ind in range(len(optimizers)):
|
||||
setattr(self, f'optim{opt_ind}', results.pop(0))
|
||||
|
||||
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
|
||||
setattr(self, f'optim{opt_ind}', optimizer)
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
path = Path(path)
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.11.0'
|
||||
__version__ = '0.11.2'
|
||||
|
||||
Reference in New Issue
Block a user