allow for setting beta schedules of unets differently in the decoder, as the schedules used in the paper were cosine, cosine, linear

Phil Wang
2022-06-20 08:56:32 -07:00
parent f5a906f5d3
commit 138079ca83
4 changed files with 87 additions and 55 deletions
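
The per-unet beta schedule feature itself lives in the other changed files not shown in the hunks below. A minimal usage sketch of what the commit enables, assuming the `Decoder` constructor of this period accepts a `beta_schedule` tuple with one entry per unet (parameter names are taken from the project README of the time; all hyperparameters are placeholders):

```python
from dalle2_pytorch import Unet, Decoder

# three-unet cascade, as in the paper (hyperparameters are placeholders)
unet1 = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))
unet2 = Unet(dim = 64, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))
unet3 = Unet(dim = 64, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))

decoder = Decoder(
    unet = (unet1, unet2, unet3),
    image_sizes = (64, 256, 1024),
    timesteps = 1000,
    beta_schedule = ('cosine', 'cosine', 'linear')  # one schedule per unet, matching the paper
)
```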


@@ -24,7 +24,9 @@ def exists(val):
     return val is not None
 
 def default(val, d):
-    return val if exists(val) else d
+    if exists(val):
+        return val
+    return d() if callable(d) else d
 
 def cast_tuple(val, length = 1):
     return val if isinstance(val, tuple) else ((val,) * length)
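
The reworked `default` helper now invokes `d` when it is callable, so a fallback is only constructed when actually needed. A small self-contained sketch of the behavior:

```python
def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

assert default(5, 10) == 5        # value provided, fallback ignored
assert default(None, 10) == 10    # plain fallback returned as-is
assert default(None, list) == []  # callable fallback is invoked lazily
```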
@@ -574,8 +576,8 @@ def decoder_sample_in_chunks(fn):
 class DecoderTrainer(nn.Module):
     def __init__(
         self,
-        accelerator,
         decoder,
+        accelerator = None,
         use_ema = True,
         lr = 1e-4,
         wd = 1e-2,
@@ -589,7 +591,7 @@ class DecoderTrainer(nn.Module):
         assert isinstance(decoder, Decoder)
 
         ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
 
-        self.accelerator = accelerator
+        self.accelerator = default(accelerator, Accelerator)
 
         self.num_unets = len(decoder.unets)
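
With `accelerator` defaulting to `None` and resolved through `default(accelerator, Accelerator)`, the trainer now builds its own `Accelerator` when none is supplied, since the new `default` calls a callable fallback. A hedged usage sketch (the import path and keyword usage are assumptions; `decoder` is the one built in the earlier sketch):

```python
from dalle2_pytorch.trainer import DecoderTrainer  # import path is an assumption

# before this commit: DecoderTrainer(accelerator, decoder, ...)
trainer = DecoderTrainer(
    decoder = decoder,  # decoder from the earlier sketch
    lr = 1e-4,
    wd = 1e-2
)

# trainer.accelerator is a freshly constructed accelerate.Accelerator
```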