diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 91f8f42..f932eab 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -59,6 +59,9 @@ def default(val, d): return d() if isfunction(d) else d def cast_tuple(val, length = 1): + if isinstance(val, list): + val = tuple(val) + return val if isinstance(val, tuple) else ((val,) * length) def module_device(module): diff --git a/setup.py b/setup.py index 64cce42..ebfa598 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.3.6', + version = '0.3.7', license='MIT', description = 'DALL-E 2', author = 'Phil Wang', diff --git a/train_decoder.py b/train_decoder.py index e3daa5e..f179c01 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -86,7 +86,7 @@ def create_decoder(device, decoder_config, unets_config): )) decoder = Decoder( - unet=tuple(unets), # Must be tuple because of cast_tuple + unet=unets, **decoder_config ) decoder.to(device=device)