accept unets as list for decoder

This commit is contained in:
Phil Wang
2022-05-20 20:31:26 -07:00
parent f526f14d7c
commit 80497e9839
3 changed files with 5 additions and 2 deletions

View File

@@ -86,7 +86,7 @@ def create_decoder(device, decoder_config, unets_config):
))
decoder = Decoder(
unet=tuple(unets), # Must be tuple because of cast_tuple
unet=unets,
**decoder_config
)
decoder.to(device=device)