mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
move neural network creations off the configuration file into the pydantic classes
This commit is contained in:
@@ -85,20 +85,6 @@ def create_dataloaders(
|
||||
"test_sampling": test_sampling_dataloader
|
||||
}
|
||||
|
||||
|
||||
def create_decoder(device, decoder_config, unets_config):
|
||||
"""Creates a sample decoder"""
|
||||
|
||||
unets = [Unet(**config.dict()) for config in unets_config]
|
||||
|
||||
decoder = Decoder(
|
||||
unet=unets,
|
||||
**decoder_config.dict()
|
||||
)
|
||||
|
||||
decoder.to(device=device)
|
||||
return decoder
|
||||
|
||||
def get_dataset_keys(dataloader):
|
||||
"""
|
||||
It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.
|
||||
@@ -428,7 +414,7 @@ def initialize_training(config):
|
||||
**config.data.dict()
|
||||
)
|
||||
|
||||
decoder = create_decoder(device, config.decoder, config.unets)
|
||||
decoder = config.decoder.create().to(device = device)
|
||||
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||
print(print_ribbon("Loaded Config", repeat=40))
|
||||
print(f"Number of parameters: {num_parameters}")
|
||||
|
||||
Reference in New Issue
Block a user