accept unets as list for decoder

This commit is contained in:
Phil Wang
2022-05-20 20:31:26 -07:00
parent f526f14d7c
commit 80497e9839
3 changed files with 5 additions and 2 deletions

View File

@@ -59,6 +59,9 @@ def default(val, d):
return d() if isfunction(d) else d
def cast_tuple(val, length = 1):
if isinstance(val, list):
val = tuple(val)
return val if isinstance(val, tuple) else ((val,) * length)
def module_device(module):

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.3.6',
version = '0.3.7',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',

View File

@@ -86,7 +86,7 @@ def create_decoder(device, decoder_config, unets_config):
))
decoder = Decoder(
unet=tuple(unets), # Must be tuple because of cast_tuple
unet=unets,
**decoder_config
)
decoder.to(device=device)