complete ddim integration of diffusion prior as well as decoder for each unet, feature complete for https://github.com/lucidrains/DALLE2-pytorch/issues/157

This commit is contained in:
Phil Wang
2022-07-09 17:25:34 -07:00
parent 68de937aac
commit b7e22f7da0
5 changed files with 186 additions and 34 deletions

View File

@@ -536,11 +536,19 @@ class DecoderTrainer(nn.Module):
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
clip = decoder.clip
clip.to(precision_type)
decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
self.decoder = decoder
# prepare dataloaders
train_loader = val_loader = None
if exists(dataloaders):
train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
self.train_loader = train_loader
self.val_loader = val_loader
self.decoder = decoder
# store optimizers