mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
accept unets as list for decoder
This commit is contained in:
@@ -59,6 +59,9 @@ def default(val, d):
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
if isinstance(val, list):
|
||||
val = tuple(val)
|
||||
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
def module_device(module):
|
||||
|
||||
Reference in New Issue
Block a user