complete contextmanager method for keeping only one unet on the GPU during training or inference

Phil Wang
2022-04-20 10:46:13 -07:00
parent 6f941a219a
commit 27a33e1b20
3 changed files with 22 additions and 5 deletions

README.md

@@ -411,8 +411,8 @@ Offer training wrappers
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [x] add efficient attention in unet
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution's unet separately)
- [ ] build out latent diffusion architecture in a separate file, as it is not faithful to dalle-2 (but offer it as a setting)
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution's unet separately)
- [ ] become an expert with unets, clean up unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] train on a toy task, offer in colab

dalle2_pytorch/dalle2_pytorch.py

@@ -2,6 +2,7 @@ import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from contextlib import contextmanager
import torch
import torch.nn.functional as F
@@ -1141,6 +1142,20 @@ class Decoder(nn.Module):
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    @contextmanager
    def one_unet_in_gpu(self, unet_number):
        assert 0 < unet_number <= len(self.unets)
        index = unet_number - 1
        self.cuda()
        self.unets.cpu()

        unet = self.unets[index]
        unet.cuda()

        yield
        self.unets.cpu()

    def get_text_encodings(self, text):
        text_encodings = self.clip.text_transformer(text)
        return text_encodings[:, 1:]
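
The new one_unet_in_gpu context manager above is the heart of the commit: park every unet on the CPU, move only the requested one onto the GPU, yield to the caller, then offload everything again on exit. The following is a minimal standalone sketch of that same pattern, using a toy ModuleList instead of the actual Decoder class; the class and variable names here are illustrative and not part of the commit.

from contextlib import contextmanager

import torch
from torch import nn

class ToyCascade(nn.Module):
    def __init__(self, num_stages = 3):
        super().__init__()
        # stand-ins for the cascade of unets held by the Decoder
        self.stages = nn.ModuleList([nn.Linear(8, 8) for _ in range(num_stages)])

    @contextmanager
    def one_stage_in_gpu(self, stage_number):
        assert 0 < stage_number <= len(self.stages)
        index = stage_number - 1

        self.stages.cpu()             # park every stage on the CPU
        stage = self.stages[index]
        stage.cuda()                  # bring only the requested stage onto the GPU

        yield stage                   # caller runs with a single resident stage

        self.stages.cpu()             # offload it again when the block exits

cascade = ToyCascade()

if torch.cuda.is_available():
    with cascade.one_stage_in_gpu(2) as stage:
        out = stage(torch.randn(1, 8).cuda())   # only stage 2 occupies GPU memory here

Two small differences from the committed method: the sketch yields the selected module for convenience (the Decoder version yields nothing and the caller indexes self.unets itself), and it skips the initial self.cuda() call that keeps the non-unet parts of the Decoder, such as CLIP and the noise-schedule buffers, on the GPU.
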
@@ -1245,7 +1260,9 @@ class Decoder(nn.Module):
        text_encodings = self.get_text_encodings(text) if exists(text) else None

        img = None

        for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
        for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
            with self.one_unet_in_gpu(ind + 1):
                shape = (batch_size, channels, image_size, image_size)
                img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
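
The sampling hunk above covers the inference half of the commit title: each unet in the cascade is pulled onto the GPU only for its own p_sample_loop call. The training half would presumably wrap each per-unet optimization step the same way; a rough sketch under assumed names follows (decoder, dataloader, optimizer, images, text, and the keyword arguments of the forward call are assumptions, not shown in this diff).

# rough training-side sketch, not part of this commit; everything except
# decoder.one_unet_in_gpu and decoder.unets is an assumed name or signature
for unet_number in range(1, len(decoder.unets) + 1):    # train one resolution at a time
    for images, text in dataloader:
        with decoder.one_unet_in_gpu(unet_number):      # only this unet resides on the GPU
            loss = decoder(images.cuda(), text = text.cuda(), unet_number = unet_number)  # assumed signature
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
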

setup.py

@@ -10,7 +10,7 @@ setup(
      'dream = dalle2_pytorch.cli:dream'
    ],
  },
  version = '0.0.27',
  version = '0.0.28',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',