mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)
complete contextmanager method for keeping only one unet in GPU during training or inference
@@ -411,8 +411,8 @@ Offer training wrappers
 - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
 - [x] add efficient attention in unet
 - [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
+- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
 - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
-- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] train on a toy task, offer in colab
@@ -2,6 +2,7 @@ import math
 from tqdm import tqdm
 from inspect import isfunction
 from functools import partial
+from contextlib import contextmanager

 import torch
 import torch.nn.functional as F
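The `contextmanager` import added here is `contextlib.contextmanager` from the Python standard library; it is what turns the generator-style `one_unet_in_gpu` method added in the next hunk into an object usable in a `with` statement.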
@@ -1141,6 +1142,20 @@ class Decoder(nn.Module):
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

+    @contextmanager
+    def one_unet_in_gpu(self, unet_number):
+        assert 0 < unet_number <= len(self.unets)
+        index = unet_number - 1
+        self.cuda()
+        self.unets.cpu()
+
+        unet = self.unets[index]
+        unet.cuda()
+
+        yield
+
+        self.unets.cpu()
+
     def get_text_encodings(self, text):
         text_encodings = self.clip.text_transformer(text)
         return text_encodings[:, 1:]
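For illustration, a minimal usage sketch of the new context manager. The `decoder` instance is an assumed, already-constructed `Decoder` (not part of this commit), and a CUDA device is required, mirroring the `.cuda()` calls inside the method: while the `with` block is active only the selected unet sits on the GPU, and every unet is moved back to the CPU when the block exits.

# usage sketch with assumed names (decoder = Decoder(...)); requires a CUDA device
with decoder.one_unet_in_gpu(unet_number = 1):
    # inside the block, the first unet lives on the GPU while the others stay on the CPU
    assert next(decoder.unets[0].parameters()).is_cuda
    # ... run training or sampling steps for this unet here ...

# on exit the method moves all unets back to the CPU
assert not next(decoder.unets[0].parameters()).is_cuda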
@@ -1245,9 +1260,11 @@
         text_encodings = self.get_text_encodings(text) if exists(text) else None

         img = None
-        for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
-            shape = (batch_size, channels, image_size, image_size)
-            img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
+
+        for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
+            with self.one_unet_in_gpu(ind + 1):
+                shape = (batch_size, channels, image_size, image_size)
+                img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)

         return img
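Note on the sampling hunk above: `one_unet_in_gpu` expects a 1-indexed unet number (the new method asserts `0 < unet_number <= len(self.unets)`), which is why the rewritten loop passes `ind + 1` from the 0-indexed `enumerate`. During cascaded sampling each unet is therefore resident on the GPU only for the duration of its own `p_sample_loop` call, with the remaining unets held on the CPU.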