mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
cleanup and DRY a little
This commit is contained in:
@@ -1142,19 +1142,24 @@ class Decoder(nn.Module):
|
|||||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||||
|
|
||||||
@contextmanager
|
def get_unet(self, unet_number):
|
||||||
def one_unet_in_gpu(self, unet_number):
|
|
||||||
assert 0 < unet_number <= len(self.unets)
|
assert 0 < unet_number <= len(self.unets)
|
||||||
index = unet_number - 1
|
index = unet_number - 1
|
||||||
|
return self.unets[index]
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def one_unet_in_gpu(self, unet_number = None, unet = None):
|
||||||
|
assert exists(unet_number) ^ exists(unet)
|
||||||
|
|
||||||
|
if exists(unet_number):
|
||||||
|
unet = self.get_unet(unet_number)
|
||||||
|
|
||||||
self.cuda()
|
self.cuda()
|
||||||
self.unets.cpu()
|
self.unets.cpu()
|
||||||
|
|
||||||
unet = self.unets[index]
|
|
||||||
unet.cuda()
|
unet.cuda()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
unet.cpu()
|
||||||
self.unets.cpu()
|
|
||||||
|
|
||||||
def get_text_encodings(self, text):
|
def get_text_encodings(self, text):
|
||||||
text_encodings = self.clip.text_transformer(text)
|
text_encodings = self.clip.text_transformer(text)
|
||||||
@@ -1261,8 +1266,8 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
img = None
|
img = None
|
||||||
|
|
||||||
for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
|
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
|
||||||
with self.one_unet_in_gpu(ind + 1):
|
with self.one_unet_in_gpu(unet = unet):
|
||||||
shape = (batch_size, channels, image_size, image_size)
|
shape = (batch_size, channels, image_size, image_size)
|
||||||
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
|
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
|
||||||
|
|
||||||
@@ -1271,11 +1276,10 @@ class Decoder(nn.Module):
|
|||||||
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
|
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
|
||||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||||
unet_number = default(unet_number, 1)
|
unet_number = default(unet_number, 1)
|
||||||
assert 1 <= unet_number <= len(self.unets)
|
|
||||||
|
|
||||||
index = unet_number - 1
|
unet = self.get_unet(unet_number)
|
||||||
unet = self.unets[index]
|
|
||||||
target_image_size = self.image_sizes[index]
|
target_image_size = self.image_sizes[unet_number - 1]
|
||||||
|
|
||||||
b, c, h, w, device, = *image.shape, image.device
|
b, c, h, w, device, = *image.shape, image.device
|
||||||
|
|
||||||
@@ -1289,7 +1293,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||||
|
|
||||||
lowres_cond_img = image if index > 0 else None
|
lowres_cond_img = image if unet_number > 1 else None
|
||||||
ddpm_image = resize_image_to(image, target_image_size)
|
ddpm_image = resize_image_to(image, target_image_size)
|
||||||
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user