mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
complete ddim integration of diffusion prior as well as decoder for each unet, feature complete for https://github.com/lucidrains/DALLE2-pytorch/issues/157
This commit is contained in:
@@ -536,11 +536,19 @@ class DecoderTrainer(nn.Module):
|
||||
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
|
||||
clip = decoder.clip
|
||||
clip.to(precision_type)
|
||||
decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
|
||||
|
||||
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
# prepare dataloaders
|
||||
|
||||
train_loader = val_loader = None
|
||||
if exists(dataloaders):
|
||||
train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
|
||||
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
self.decoder = decoder
|
||||
|
||||
# store optimizers
|
||||
|
||||
|
||||
Reference in New Issue
Block a user