mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 22:34:21 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b994601ae | ||
|
|
fddf66e91e |
@@ -451,6 +451,8 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||
|
||||
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
|
||||
optimizers = []
|
||||
|
||||
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.11.4'
|
||||
__version__ = '0.11.5'
|
||||
|
||||
@@ -258,8 +258,8 @@ def train(
|
||||
is_master = accelerator.process_index == 0
|
||||
|
||||
trainer = DecoderTrainer(
|
||||
accelerator,
|
||||
decoder,
|
||||
decoder=decoder,
|
||||
accelerator=accelerator,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user