mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
allow for setting the beta schedules of the unets differently in the decoder, as the schedules used in the paper were cosine, cosine, linear
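For context, a minimal sketch of what the commit message describes, assuming a three-unet Decoder whose beta_schedule keyword accepts one schedule per unet; the unet hyperparameters and image sizes below are illustrative placeholders, and the clip/embedding wiring is omitted for brevity:

    from dalle2_pytorch import Unet, Decoder

    # three illustrative unets for a cascaded decoder (hyperparameters are placeholders)
    unet1 = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))
    unet2 = Unet(dim = 64, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))
    unet3 = Unet(dim = 32, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))

    decoder = Decoder(
        unet = (unet1, unet2, unet3),
        image_sizes = (64, 256, 1024),
        timesteps = 1000,
        # one noise schedule per unet, matching the paper: cosine, cosine, linear
        beta_schedule = ('cosine', 'cosine', 'linear')
    )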
@@ -24,7 +24,9 @@ def exists(val):
     return val is not None
 
 def default(val, d):
-    return val if exists(val) else d
+    if exists(val):
+        return val
+    return d() if callable(d) else d
 
 def cast_tuple(val, length = 1):
     return val if isinstance(val, tuple) else ((val,) * length)
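The rewritten default above makes the fallback lazy: a callable fallback is only invoked when val is actually missing. A quick demonstration:

    def exists(val):
        return val is not None

    def default(val, d):
        if exists(val):
            return val
        return d() if callable(d) else d

    default(5, 10)               # -> 5, value exists so the fallback is ignored
    default(None, 10)            # -> 10, a plain fallback is returned as-is
    default(None, list)          # -> [], a callable fallback is invoked lazily
    default('x', lambda: 1 / 0)  # -> 'x', the callable is never evaluated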
@@ -574,8 +576,8 @@ def decoder_sample_in_chunks(fn):
 class DecoderTrainer(nn.Module):
     def __init__(
         self,
-        accelerator,
         decoder,
+        accelerator = None,
         use_ema = True,
         lr = 1e-4,
         wd = 1e-2,
@@ -589,7 +591,7 @@ class DecoderTrainer(nn.Module):
         assert isinstance(decoder, Decoder)
         ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
 
-        self.accelerator = accelerator
+        self.accelerator = default(accelerator, Accelerator)
 
         self.num_unets = len(decoder.unets)
 