Compare commits


5 Commits

Author           SHA1         Date                         Message
Phil Wang        1f1557c614   2022-07-11 10:56:19 -07:00   make it so even if text mask is omitted, it will be derived based on whether text encodings are all 0s or not, simplify dataloading
Aidan Dempster   1a217e99e3   2022-07-10 16:45:59 -07:00   Unet parameter count is now shown (#202)
Phil Wang        7ea314e2f0   2022-07-10 09:44:38 -07:00   allow for final l2norm clamping of the sampled image embed
Phil Wang        4173e88121   2022-07-09 20:57:26 -07:00   more accurate readme
Phil Wang        3dae43fa0e   2022-07-09 19:01:37 -07:00   fix misnamed variable, thanks to @nousr
4 changed files with 42 additions and 12 deletions

README.md

@@ -355,7 +355,8 @@ prior_network = DiffusionPriorNetwork(
 diffusion_prior = DiffusionPrior(
     net = prior_network,
     clip = clip,
-    timesteps = 100,
+    timesteps = 1000,
+    sample_timesteps = 64,
     cond_drop_prob = 0.2
 ).cuda()
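Note: with `timesteps = 1000` for training and `sample_timesteps = 64`, the prior is trained on the full schedule but sampled with far fewer DDIM steps. A minimal sketch of how such a shortened schedule can be derived (illustrative only, not the library's exact code):

```python
import torch

timesteps = 1000       # training timesteps, as in the snippet above
sample_timesteps = 64  # number of DDIM sampling steps

# evenly spaced times from (timesteps - 1) down to -1; each adjacent
# pair (time, time_next) is one DDIM denoising step
times = torch.linspace(-1., timesteps - 1, steps = sample_timesteps + 1)
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))

print(time_pairs[:3])  # [(999, 983), (983, 967), (967, 952)]
```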

dalle2_pytorch/dalle2_pytorch.py

@@ -220,6 +220,7 @@ class XClipAdapter(BaseClipAdapter):
         encoder_output = self.clip.text_transformer(text)
         text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
+        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
         return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
 
     @torch.no_grad()
@@ -255,6 +256,7 @@ class CoCaAdapter(BaseClipAdapter):
         text = text[..., :self.max_text_len]
         text_mask = text != 0
         text_embed, text_encodings = self.clip.embed_text(text)
+        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
         return EmbeddedText(text_embed, text_encodings, text_mask)
 
     @torch.no_grad()
@@ -314,6 +316,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         text_embed = self.clip.encode_text(text)
         text_encodings = self.text_encodings
+        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
         del self.text_encodings
         return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
@@ -922,11 +925,12 @@ class DiffusionPrior(nn.Module):
         loss_type = "l2",
         predict_x_start = True,
         beta_schedule = "cosine",
-        condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
-        sampling_clamp_l2norm = False,
+        condition_on_text_encodings = True,  # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
+        sampling_clamp_l2norm = False,       # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
+        sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
         training_clamp_l2norm = False,
         init_image_embed_l2norm = False,
-        image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
+        image_embed_scale = None,            # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
         clip_adapter_overrides = dict()
     ):
         super().__init__()
@@ -963,23 +967,32 @@ class DiffusionPrior(nn.Module):
         self.condition_on_text_encodings = condition_on_text_encodings
 
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
 
         self.predict_x_start = predict_x_start
 
         # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
 
         self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
 
         # whether to force an l2norm, similar to clipping denoised, when sampling
 
         self.sampling_clamp_l2norm = sampling_clamp_l2norm
+        self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
         self.training_clamp_l2norm = training_clamp_l2norm
         self.init_image_embed_l2norm = init_image_embed_l2norm
 
         # device tracker
 
         self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
 
     @property
     def device(self):
         return self._dummy.device
 
+    def l2norm_clamp_embed(self, image_embed):
+        return l2norm(image_embed) * self.image_embed_scale
+
     def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -1020,6 +1033,9 @@ class DiffusionPrior(nn.Module):
             times = torch.full((batch,), i, device = device, dtype = torch.long)
             image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
 
+        if self.sampling_final_clamp_l2norm and self.predict_x_start:
+            image_embed = self.l2norm_clamp_embed(image_embed)
+
         return image_embed
 
     @torch.no_grad()
@@ -1055,15 +1071,18 @@ class DiffusionPrior(nn.Module):
                 x_start.clamp_(-1., 1.)
 
             if self.predict_x_start and self.sampling_clamp_l2norm:
-                x_start = l2norm(x_start) * self.image_embed_scale
+                x_start = self.l2norm_clamp_embed(x_start)
 
             c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
             c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
             noise = torch.randn_like(image_embed) if time_next > 0 else 0.
 
-            img = x_start * alpha_next.sqrt() + \
-                  c1 * noise + \
-                  c2 * pred_noise
+            image_embed = x_start * alpha_next.sqrt() + \
+                          c1 * noise + \
+                          c2 * pred_noise
+
+        if self.predict_x_start and self.sampling_final_clamp_l2norm:
+            image_embed = self.l2norm_clamp_embed(image_embed)
 
         return image_embed
@@ -1091,7 +1110,7 @@ class DiffusionPrior(nn.Module):
         )
 
         if self.predict_x_start and self.training_clamp_l2norm:
-            pred = l2norm(pred) * self.image_embed_scale
+            pred = self.l2norm_clamp_embed(pred)
 
         target = noise if not self.predict_x_start else image_embed
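The recurring `l2norm(x) * self.image_embed_scale` pattern above is now factored into the new `l2norm_clamp_embed` helper. A self-contained sketch of the operation, assuming `l2norm` is unit normalization over the last dimension (an assumption about its definition elsewhere in the file):

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    # assumed definition: unit l2 norm along the last dimension
    return F.normalize(t, dim = -1)

image_embed_dim = 512
image_embed_scale = image_embed_dim ** 0.5  # default from the __init__ hunk above

def l2norm_clamp_embed(image_embed):
    # snap the embedding back onto the sphere of radius sqrt(dim),
    # the embedding-space analogue of clamping images to [-1, 1]
    return l2norm(image_embed) * image_embed_scale

x = torch.randn(4, image_embed_dim) * 10.  # embeddings that have drifted in norm
x = l2norm_clamp_embed(x)
print(x.norm(dim = -1))  # all ~22.63, i.e. sqrt(512)
```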
@@ -1181,6 +1200,7 @@ class DiffusionPrior(nn.Module):
         if self.condition_on_text_encodings:
             assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
+            text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
             text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
 
         # timestep conditioning from ddpm
@@ -2394,6 +2414,9 @@ class Decoder(nn.Module):
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
 
+        if self.condition_on_text_encodings:
+            text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
+
         img = None
         is_cuda = next(self.parameters()).is_cuda
@@ -2477,6 +2500,9 @@ class Decoder(nn.Module):
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
 
+        if self.condition_on_text_encodings:
+            text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
+
         lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
         image = resize_image_to(image, target_image_size)
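The mask changes in this file all follow the same rule from the top commit: if `text_mask` is not supplied, treat any position whose encoding vector is all zeros as padding. A standalone sketch with dummy tensors:

```python
import torch

# dummy text encodings: batch 2, 5 positions, dim 8;
# the last two positions of the first sequence are zero padding
text_encodings = torch.randn(2, 5, 8)
text_encodings[0, 3:] = 0.

# derive the mask as in the hunks above: a position is valid
# iff its encoding is not entirely zeros
text_mask = torch.any(text_encodings != 0., dim = -1)  # shape (2, 5), dtype bool

# and zero out masked positions, as the adapters now do
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)

print(text_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```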

dalle2_pytorch/version.py

@@ -1 +1 @@
-__version__ = '0.19.4'
+__version__ = '0.20.0'

train_decoder.py

@@ -557,7 +557,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     # Create the decoder model and print basic info
     decoder = config.decoder.create()
-    num_parameters = sum(p.numel() for p in decoder.parameters())
+    get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
 
     # Create and initialize the tracker if we are the master
     tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
@@ -586,7 +586,10 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     accelerator.print(print_ribbon("Loaded Config", repeat=40))
     accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
     accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
-    accelerator.print(f"Number of parameters: {num_parameters}")
+    accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
+    for i, unet in enumerate(decoder.unets):
+        accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
 
     train(dataloaders, decoder, accelerator,
           tracker=tracker,
           inference_device=accelerator.device,
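The new `get_num_parameters` lambda counts every parameter by default and only trainable ones with `only_training=True`. A quick sketch of the same helper against a toy module:

```python
import torch.nn as nn

get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))

# toy decoder stand-in: freeze the first layer to separate the two counts
model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 5))
for p in model[0].parameters():
    p.requires_grad = False

print(get_num_parameters(model))                      # 325 total
print(get_num_parameters(model, only_training=True))  # 105 training
```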