Compare commits

...

3 Commits

Author SHA1 Message Date
Phil Wang
fb662a62f3 fix another bug thanks to @xiankgx 2022-04-29 07:38:32 -07:00
Phil Wang
587c8c9b44 optimize for clarity 2022-04-28 21:59:13 -07:00
Phil Wang
aa900213e7 force first unet in the cascade to be conditioned on image embeds 2022-04-28 20:53:15 -07:00
2 changed files with 30 additions and 33 deletions

View File

@@ -223,7 +223,18 @@ class BaseGaussianDiffusion(nn.Module):
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
if loss_type == 'l1':
loss_fn = F.l1_loss
elif loss_type == 'l2':
loss_fn = F.mse_loss
elif loss_type == 'huber':
loss_fn = F.smooth_l1_loss
else:
raise NotImplementedError()
self.loss_type = loss_type
self.loss_fn = loss_fn
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
@@ -703,29 +714,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
def p_losses(self, image_embed, t, text_cond, noise = None): def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise) image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
x_recon = self.net( pred = self.net(
image_embed_noisy,
t, times,
cond_drop_prob = self.cond_drop_prob,
**text_cond
)
to_predict = noise if not self.predict_x_start else image_embed target = noise if not self.predict_x_start else image_embed
if self.loss_type == 'l1':
loss = F.l1_loss(to_predict, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(to_predict, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(to_predict, x_recon)
else:
raise NotImplementedError()
loss = self.loss_fn(pred, target)
return loss
@torch.no_grad()
@@ -1066,13 +1069,14 @@ class Unet(nn.Module):
self,
*,
lowres_cond,
channels channels,
cond_on_image_embeds
):
if lowres_cond == self.lowres_cond and channels == self.channels: if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels} updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
return self.__class__(**updated_kwargs) return self.__class__(**{**self._locals, **updated_kwargs})
def forward_with_cond_scale(
self,
@@ -1209,7 +1213,7 @@ class LowresConditioner(nn.Module):
target_image_size = cast_tuple(target_image_size, 2)
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode) cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
if self.training:
# when training, blur the low resolution conditional image
@@ -1279,6 +1283,7 @@ class Decoder(BaseGaussianDiffusion):
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
cond_on_image_embeds = is_first,
channels = unet_channels
)
@@ -1386,14 +1391,14 @@ class Decoder(BaseGaussianDiffusion):
return img
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None): def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise) x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
x_recon = unet( pred = unet(
x_noisy,
t, times,
image_embed = image_embed,
text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img,
@@ -1402,15 +1407,7 @@ class Decoder(BaseGaussianDiffusion):
target = noise if not predict_x_start else x_start
if self.loss_type == 'l1': loss = self.loss_fn(pred, target)
loss = F.l1_loss(target, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(target, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(target, x_recon)
else:
raise NotImplementedError()
return loss
@torch.no_grad()

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.63', version = '0.0.65',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',