mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
accept unets as list for decoder
This commit is contained in:
@@ -86,7 +86,7 @@ def create_decoder(device, decoder_config, unets_config):
|
||||
))
|
||||
|
||||
decoder = Decoder(
|
||||
unet=tuple(unets), # Must be tuple because of cast_tuple
|
||||
unet=unets,
|
||||
**decoder_config
|
||||
)
|
||||
decoder.to(device=device)
|
||||
|
||||
Reference in New Issue
Block a user