diff --git a/README.md b/README.md
index afae8af..2ffb034 100644
--- a/README.md
+++ b/README.md
@@ -411,8 +411,8 @@ Offer training wrappers
 - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
 - [x] add efficient attention in unet
 - [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
+- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
 - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
-- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] train on a toy task, offer in colab
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 00a3032..4bb41cf 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -2,6 +2,7 @@ import math
 from tqdm import tqdm
 from inspect import isfunction
 from functools import partial
+from contextlib import contextmanager
 
 import torch
 import torch.nn.functional as F
@@ -1141,6 +1142,20 @@ class Decoder(nn.Module):
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+    @contextmanager
+    def one_unet_in_gpu(self, unet_number):
+        assert 0 < unet_number <= len(self.unets)
+        index = unet_number - 1
+        self.cuda()
+        self.unets.cpu()
+
+        unet = self.unets[index]
+        unet.cuda()
+
+        yield
+
+        self.unets.cpu()
+
     def get_text_encodings(self, text):
         text_encodings = self.clip.text_transformer(text)
         return text_encodings[:, 1:]
@@ -1245,9 +1260,11 @@ class Decoder(nn.Module):
         text_encodings = self.get_text_encodings(text) if exists(text) else None
 
         img = None
-        for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
-            shape = (batch_size, channels, image_size, image_size)
-            img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
+
+        for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
+            with self.one_unet_in_gpu(ind + 1):
+                shape = (batch_size, channels, image_size, image_size)
+                img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
 
         return img
diff --git a/setup.py b/setup.py
index 41e80b0..e6d66c9 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.27',
+  version = '0.0.28',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
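
The patch above adds a one_unet_in_gpu context manager to Decoder so that, while sampling the cascade, only the unet currently in use sits on the GPU and the rest stay on the CPU. Below is a minimal, self-contained sketch of that offloading pattern; the Cascade class and the tiny nn.Linear stand-ins for unets are hypothetical and only illustrate the idea, they are not part of dalle2-pytorch.

from contextlib import contextmanager

import torch
from torch import nn

class Cascade(nn.Module):
    # Hypothetical holder for a cascade of models; illustration only, not the library's API.
    def __init__(self, unets):
        super().__init__()
        self.unets = nn.ModuleList(unets)

    @contextmanager
    def one_unet_in_gpu(self, unet_number):
        # unet_number is 1-indexed, mirroring the method added in the patch
        assert 0 < unet_number <= len(self.unets)
        self.unets.cpu()                    # park every unet on the CPU
        unet = self.unets[unet_number - 1]
        unet.cuda()                         # move only the unet that is needed onto the GPU
        try:
            yield unet
        finally:
            self.unets.cpu()                # free GPU memory again afterwards

cascade = Cascade([nn.Linear(8, 8) for _ in range(3)])   # stand-ins for real unets

if torch.cuda.is_available():
    for ind in range(len(cascade.unets)):
        with cascade.one_unet_in_gpu(ind + 1) as unet:
            out = unet(torch.randn(1, 8).cuda())

Note that the sketch wraps the yield in try/finally so the unets are returned to the CPU even if sampling raises; the method added in the patch yields without that guard, so an exception inside the with block would leave the active unet on the GPU.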