Compare commits

..

13 Commits

Author SHA1 Message Date
Phil Wang
e0835acca9 generate text mask within the unet and diffusion prior itself from the text encodings, if not given 2022-07-12 12:54:59 -07:00
Phil Wang
e055793e5d shoutout for @MalumaDev 2022-07-11 16:12:35 -07:00
Phil Wang
1d9ef99288 add PixelShuffleUpsample thanks to @MalumaDev and @marunine for running the experiment and verifyng absence of checkboard artifacts 2022-07-11 16:07:23 -07:00
Phil Wang
bdd62c24b3 zero init final projection in unet, since openai and @crowsonkb are both doing it 2022-07-11 13:22:06 -07:00
Phil Wang
1f1557c614 make it so even if text mask is omitted, it will be derived based on whether text encodings are all 0s or not, simplify dataloading 2022-07-11 10:56:19 -07:00
Aidan Dempster
1a217e99e3 Unet parameter count is now shown (#202) 2022-07-10 16:45:59 -07:00
Phil Wang
7ea314e2f0 allow for final l2norm clamping of the sampled image embed 2022-07-10 09:44:38 -07:00
Phil Wang
4173e88121 more accurate readme 2022-07-09 20:57:26 -07:00
Phil Wang
3dae43fa0e fix misnamed variable, thanks to @nousr 2022-07-09 19:01:37 -07:00
Phil Wang
a598820012 do not noise for the last step in ddim 2022-07-09 18:38:40 -07:00
Phil Wang
4878762627 fix for small validation bug for sampling steps 2022-07-09 17:31:54 -07:00
Phil Wang
47ae17b36e more informative error for something that tripped me up 2022-07-09 17:28:14 -07:00
Phil Wang
b7e22f7da0 complete ddim integration of diffusion prior as well as decoder for each unet, feature complete for https://github.com/lucidrains/DALLE2-pytorch/issues/157 2022-07-09 17:25:34 -07:00
5 changed files with 89 additions and 34 deletions

View File

@@ -45,6 +45,7 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
- <a href="https://github.com/malumadev">MalumaDev</a> for proposing the use of pixel shuffle upsampler for fixing checkboard artifacts
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
@@ -355,7 +356,8 @@ prior_network = DiffusionPriorNetwork(
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
timesteps = 1000,
sample_timesteps = 64,
cond_drop_prob = 0.2
).cuda()

View File

@@ -77,6 +77,11 @@ def cast_tuple(val, length = None):
def module_device(module):
return next(module.parameters()).device
def zero_init_(m):
nn.init.zeros_(m.weight)
if exists(m.bias):
nn.init.zeros_(m.bias)
@contextmanager
def null_context(*args, **kwargs):
yield
@@ -220,6 +225,7 @@ class XClipAdapter(BaseClipAdapter):
encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
@torch.no_grad()
@@ -255,6 +261,7 @@ class CoCaAdapter(BaseClipAdapter):
text = text[..., :self.max_text_len]
text_mask = text != 0
text_embed, text_encodings = self.clip.embed_text(text)
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(text_embed, text_encodings, text_mask)
@torch.no_grad()
@@ -314,6 +321,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
del self.text_encodings
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
@@ -864,7 +872,7 @@ class DiffusionPriorNetwork(nn.Module):
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
if not exists(mask):
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
mask = torch.any(text_encodings != 0., dim = -1)
# classifier free guidance
@@ -922,11 +930,12 @@ class DiffusionPrior(nn.Module):
loss_type = "l2",
predict_x_start = True,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False,
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
training_clamp_l2norm = False,
init_image_embed_l2norm = False,
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
super().__init__()
@@ -963,23 +972,32 @@ class DiffusionPrior(nn.Module):
self.condition_on_text_encodings = condition_on_text_encodings
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
self.predict_x_start = predict_x_start
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
# whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm
self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
# device tracker
self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
@property
def device(self):
return self._dummy.device
def l2norm_clamp_embed(self, image_embed):
return l2norm(image_embed) * self.image_embed_scale
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -1020,6 +1038,9 @@ class DiffusionPrior(nn.Module):
times = torch.full((batch,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
if self.sampling_final_clamp_l2norm and self.predict_x_start:
image_embed = self.l2norm_clamp_embed(image_embed)
return image_embed
@torch.no_grad()
@@ -1055,15 +1076,18 @@ class DiffusionPrior(nn.Module):
x_start.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm:
x_start = l2norm(x_start) * self.image_embed_scale
x_start = self.l2norm_clamp_embed(x_start)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
new_noise = torch.randn_like(image_embed)
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
img = x_start * alpha_next.sqrt() + \
c1 * new_noise + \
c2 * pred_noise
image_embed = x_start * alpha_next.sqrt() + \
c1 * noise + \
c2 * pred_noise
if self.predict_x_start and self.sampling_final_clamp_l2norm:
image_embed = self.l2norm_clamp_embed(image_embed)
return image_embed
@@ -1091,7 +1115,7 @@ class DiffusionPrior(nn.Module):
)
if self.predict_x_start and self.training_clamp_l2norm:
pred = l2norm(pred) * self.image_embed_scale
pred = self.l2norm_clamp_embed(pred)
target = noise if not self.predict_x_start else image_embed
@@ -1198,16 +1222,35 @@ class DiffusionPrior(nn.Module):
# decoder
def ConvTransposeUpsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
class PixelShuffleUpsample(nn.Module):
"""
code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
"""
def __init__(self, dim, dim_out = None):
super().__init__()
dim_out = default(dim_out, dim)
conv = nn.Conv2d(dim, dim_out * 4, 1)
def NearestUpsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
nn.Conv2d(dim, dim_out, 3, padding = 1)
)
self.net = nn.Sequential(
conv,
nn.SiLU(),
nn.PixelShuffle(2)
)
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, h, w = conv.weight.shape
conv_weight = torch.empty(o // 4, i, h, w)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
return self.net(x)
def Downsample(dim, *, dim_out = None):
dim_out = default(dim_out, dim)
@@ -1471,7 +1514,7 @@ class Unet(nn.Module):
cross_embed_downsample_kernel_sizes = (2, 4),
memory_efficient = False,
scale_skip_connection = False,
nearest_upsample = False,
pixel_shuffle_upsample = True,
final_conv_kernel_size = 1,
**kwargs
):
@@ -1537,10 +1580,12 @@ class Unet(nn.Module):
# text encoding conditioning (optional)
self.text_to_cond = None
self.text_embed_dim = None
if cond_on_text_encodings:
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
self.text_embed_dim = text_embed_dim
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1583,7 +1628,7 @@ class Unet(nn.Module):
# upsample klass
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
# give memory efficient unet an initial resnet block
@@ -1647,6 +1692,8 @@ class Unet(nn.Module):
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters(
@@ -1769,6 +1816,11 @@ class Unet(nn.Module):
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
if not exists(text_mask):
text_mask = torch.any(text_encodings != 0., dim = -1)
text_tokens = self.text_to_cond(text_encodings)
text_tokens = text_tokens[:, :self.max_text_len]
@@ -1777,13 +1829,10 @@ class Unet(nn.Module):
if remainder > 0:
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
text_mask = F.pad(text_mask, (0, remainder), value = False)
if exists(text_mask):
if remainder > 0:
text_mask = F.pad(text_mask, (0, remainder), value = False)
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
@@ -2043,7 +2092,7 @@ class Decoder(nn.Module):
self.noise_schedulers = nn.ModuleList([])
for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
assert sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
noise_scheduler = NoiseScheduler(
beta_schedule = unet_beta_schedule,
@@ -2271,9 +2320,10 @@ class Decoder(nn.Module):
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if time_next > 0 else 0.
img = x_start * alpha_next.sqrt() + \
c1 * torch.randn_like(img) + \
c1 * noise + \
c2 * pred_noise
img = self.unnormalize_img(img)

View File

@@ -234,7 +234,7 @@ class DecoderConfig(BaseModel):
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sampling_timesteps: Optional[SingularOrIterable(int)] = None
sample_timesteps: Optional[SingularOrIterable(int)] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True

View File

@@ -1 +1 @@
__version__ = '0.19.0'
__version__ = '0.21.1'

View File

@@ -557,7 +557,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Create the decoder model and print basic info
decoder = config.decoder.create()
num_parameters = sum(p.numel() for p in decoder.parameters())
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
@@ -586,7 +586,10 @@ def initialize_training(config: TrainDecoderConfig, config_path):
accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
accelerator.print(f"Number of parameters: {num_parameters}")
accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
for i, unet in enumerate(decoder.unets):
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=accelerator.device,