mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
accept unets as list for decoder
This commit is contained in:
@@ -59,6 +59,9 @@ def default(val, d):
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
if isinstance(val, list):
|
||||
val = tuple(val)
|
||||
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
def module_device(module):
|
||||
|
||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.3.6',
|
||||
version = '0.3.7',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -86,7 +86,7 @@ def create_decoder(device, decoder_config, unets_config):
|
||||
))
|
||||
|
||||
decoder = Decoder(
|
||||
unet=tuple(unets), # Must be tuple because of cast_tuple
|
||||
unet=unets,
|
||||
**decoder_config
|
||||
)
|
||||
decoder.to(device=device)
|
||||
|
||||
Reference in New Issue
Block a user