mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
accept unets as list for decoder
This commit is contained in:
@@ -59,6 +59,9 @@ def default(val, d):
|
|||||||
return d() if isfunction(d) else d
|
return d() if isfunction(d) else d
|
||||||
|
|
||||||
def cast_tuple(val, length = 1):
|
def cast_tuple(val, length = 1):
|
||||||
|
if isinstance(val, list):
|
||||||
|
val = tuple(val)
|
||||||
|
|
||||||
return val if isinstance(val, tuple) else ((val,) * length)
|
return val if isinstance(val, tuple) else ((val,) * length)
|
||||||
|
|
||||||
def module_device(module):
|
def module_device(module):
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.3.6',
|
version = '0.3.7',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ def create_decoder(device, decoder_config, unets_config):
|
|||||||
))
|
))
|
||||||
|
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
unet=tuple(unets), # Must be tuple because of cast_tuple
|
unet=unets,
|
||||||
**decoder_config
|
**decoder_config
|
||||||
)
|
)
|
||||||
decoder.to(device=device)
|
decoder.to(device=device)
|
||||||
|
|||||||
Reference in New Issue
Block a user