cleanup and DRY a little

This commit is contained in:
Phil Wang
2022-04-20 10:56:22 -07:00
parent 27a33e1b20
commit f37c26e856
2 changed files with 18 additions and 14 deletions

View File

@@ -1142,19 +1142,24 @@ class Decoder(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
@contextmanager def get_unet(self, unet_number):
def one_unet_in_gpu(self, unet_number):
assert 0 < unet_number <= len(self.unets) assert 0 < unet_number <= len(self.unets)
index = unet_number - 1 index = unet_number - 1
return self.unets[index]
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
assert exists(unet_number) ^ exists(unet)
if exists(unet_number):
unet = self.get_unet(unet_number)
self.cuda() self.cuda()
self.unets.cpu() self.unets.cpu()
unet = self.unets[index]
unet.cuda() unet.cuda()
yield yield
unet.cpu()
self.unets.cpu()
def get_text_encodings(self, text): def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text) text_encodings = self.clip.text_transformer(text)
@@ -1261,8 +1266,8 @@ class Decoder(nn.Module):
img = None img = None
for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))): for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
with self.one_unet_in_gpu(ind + 1): with self.one_unet_in_gpu(unet = unet):
shape = (batch_size, channels, image_size, image_size) shape = (batch_size, channels, image_size, image_size)
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img) img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
@@ -1271,11 +1276,10 @@ class Decoder(nn.Module):
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None): def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1) unet_number = default(unet_number, 1)
assert 1 <= unet_number <= len(self.unets)
index = unet_number - 1 unet = self.get_unet(unet_number)
unet = self.unets[index]
target_image_size = self.image_sizes[index] target_image_size = self.image_sizes[unet_number - 1]
b, c, h, w, device, = *image.shape, image.device b, c, h, w, device, = *image.shape, image.device
@@ -1289,7 +1293,7 @@ class Decoder(nn.Module):
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
lowres_cond_img = image if index > 0 else None lowres_cond_img = image if unet_number > 1 else None
ddpm_image = resize_image_to(image, target_image_size) ddpm_image = resize_image_to(image, target_image_size)
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img) return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.28', version = '0.0.30',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',