diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index f288b8e..2b5e2d3 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -473,7 +473,7 @@ class DecoderTrainer(nn.Module): lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps)) - assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4' + assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4' optimizers = [] schedulers = [] diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 617a906..bdc0bed 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.12' +__version__ = '0.16.13'