Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
cleanup and DRY a little
@@ -1142,19 +1142,24 @@ class Decoder(nn.Module):
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
-    @contextmanager
-    def one_unet_in_gpu(self, unet_number):
+    def get_unet(self, unet_number):
         assert 0 < unet_number <= len(self.unets)
         index = unet_number - 1
+        return self.unets[index]
+
+    @contextmanager
+    def one_unet_in_gpu(self, unet_number = None, unet = None):
+        assert exists(unet_number) ^ exists(unet)
+
+        if exists(unet_number):
+            unet = self.get_unet(unet_number)
         self.cuda()
         self.unets.cpu()
-
-        unet = self.unets[index]
         unet.cuda()
 
         yield
 
-        self.unets.cpu()
+        unet.cpu()
 
     def get_text_encodings(self, text):
         text_encodings = self.clip.text_transformer(text)
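For orientation, the refactor splits unet lookup out of the context manager: get_unet resolves a 1-indexed unet_number to the underlying module, and one_unet_in_gpu now accepts either the number or the module itself, with the XOR assert enforcing exactly one of the two. A minimal usage sketch, assuming a decoder already built with two unets (construction elided):

# assumed: decoder is a Decoder instance holding two unets
with decoder.one_unet_in_gpu(unet_number = 1):
    ...  # only the first unet resides on the GPU inside this block

unet = decoder.get_unet(2)
with decoder.one_unet_in_gpu(unet = unet):
    ...  # the second unet is moved to the GPU, then back to the CPU on exit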
@@ -1261,8 +1266,8 @@ class Decoder(nn.Module):
 
         img = None
 
-        for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
-            with self.one_unet_in_gpu(ind + 1):
+        for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
+            with self.one_unet_in_gpu(unet = unet):
                 shape = (batch_size, channels, image_size, image_size)
                 img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
 
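This loop change is the payoff of the new keyword argument: the sampling cascade no longer tracks an index just to re-derive the unet it is already iterating over. Each stage samples at its own image_size and hands its output to the next stage as lowres_cond_img. Stripped of the model plumbing, the control flow looks like the sketch below, where sample_stage is a stand-in for self.p_sample_loop and unets / image_sizes stand in for the module attributes:

img = None
for unet, image_size in zip(unets, image_sizes):
    with one_unet_in_gpu(unet = unet):
        # each stage refines the previous stage's output;
        # the base stage receives lowres_cond_img = None
        img = sample_stage(unet, image_size, lowres_cond_img = img)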
@@ -1271,11 +1276,10 @@ class Decoder(nn.Module):
     def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
-        assert 1 <= unet_number <= len(self.unets)
 
-        index = unet_number - 1
-        unet = self.unets[index]
-        target_image_size = self.image_sizes[index]
+        unet = self.get_unet(unet_number)
+
+        target_image_size = self.image_sizes[unet_number - 1]
 
         b, c, h, w, device, = *image.shape, image.device
 
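In training, forward is still invoked once per stage of the cascade, selected by unet_number (the range assert now lives inside get_unet). A sketch of the call pattern, assuming the same two-unet decoder as above with image sizes (64, 256); the batch size, tensor shapes, and the random image_embed are placeholders for real CLIP embeddings:

import torch

images = torch.randn(4, 3, 256, 256).cuda()
image_embed = torch.randn(4, 512).cuda()  # placeholder; normally produced by CLIP or the diffusion prior

loss = decoder(images, image_embed = image_embed, unet_number = 1)  # base unet
loss.backward()

loss = decoder(images, image_embed = image_embed, unet_number = 2)  # upsampler, conditioned on the low-res image
loss.backward()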
@@ -1289,7 +1293,7 @@ class Decoder(nn.Module):
 
         text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
 
-        lowres_cond_img = image if index > 0 else None
+        lowres_cond_img = image if unet_number > 1 else None
         ddpm_image = resize_image_to(image, target_image_size)
         return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
 
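Note that image if unet_number > 1 else None is behavior-preserving: the old index was unet_number - 1, so index > 0 and unet_number > 1 are the same condition, and the rewrite simply removes the last remaining use of the intermediate index. For reference, exists and default used throughout are the repo's usual one-line helpers; reproduced here as a sketch rather than quoted from this revision:

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d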