From f37c26e85619161c86afec22dfc4879764c0b710 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 20 Apr 2022 10:56:22 -0700 Subject: [PATCH] cleanup and DRY a little --- dalle2_pytorch/dalle2_pytorch.py | 30 +++++++++++++++++------------- setup.py | 2 +- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 4bb41cf..7c2423b 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1142,19 +1142,24 @@ class Decoder(nn.Module): self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) - @contextmanager - def one_unet_in_gpu(self, unet_number): + def get_unet(self, unet_number): assert 0 < unet_number <= len(self.unets) index = unet_number - 1 + return self.unets[index] + + @contextmanager + def one_unet_in_gpu(self, unet_number = None, unet = None): + assert exists(unet_number) ^ exists(unet) + + if exists(unet_number): + unet = self.get_unet(unet_number) + self.cuda() self.unets.cpu() - unet = self.unets[index] unet.cuda() - yield - - self.unets.cpu() + unet.cpu() def get_text_encodings(self, text): text_encodings = self.clip.text_transformer(text) @@ -1261,8 +1266,8 @@ class Decoder(nn.Module): img = None - for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))): - with self.one_unet_in_gpu(ind + 1): + for unet, image_size in tqdm(zip(self.unets, self.image_sizes)): + with self.one_unet_in_gpu(unet = unet): shape = (batch_size, channels, image_size, image_size) img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img) @@ -1271,11 +1276,10 @@ class Decoder(nn.Module): def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None): assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' unet_number = default(unet_number, 1) - assert 1 <= unet_number <= len(self.unets) - index = unet_number - 1 - unet = self.unets[index] - target_image_size = self.image_sizes[index] + unet = self.get_unet(unet_number) + + target_image_size = self.image_sizes[unet_number - 1] b, c, h, w, device, = *image.shape, image.device @@ -1289,7 +1293,7 @@ class Decoder(nn.Module): text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None - lowres_cond_img = image if index > 0 else None + lowres_cond_img = image if unet_number > 1 else None ddpm_image = resize_image_to(image, target_image_size) return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img) diff --git a/setup.py b/setup.py index e6d66c9..e9db6d0 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.28', + version = '0.0.30', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',