Compare commits

...

10 Commits

Author SHA1 Message Date
Phil Wang
b2073219f0 foolproof sampling for decoder to always use eval mode (and restore training state afterwards) 2022-07-13 10:21:00 -07:00
Phil Wang
cc0f7a935c fix non pixel shuffle upsample 2022-07-13 10:16:02 -07:00
Phil Wang
95a512cb65 fix a potential bug with conditioning with blurred low resolution image, blur should be applied only 50% of the time 2022-07-13 10:11:49 -07:00
Phil Wang
972ee973bc fix issue with ddim and normalization of lowres conditioning image 2022-07-13 09:48:40 -07:00
Phil Wang
79e2a3bc77 only use the stable layernorm for final output norm in transformer 2022-07-13 07:56:30 -07:00
Aidan Dempster
544cdd0b29 Reverted to using basic dataloaders (#205)
Accelerate removes the ability to collate strings. Likely since it
cannot gather strings.
2022-07-12 18:22:27 -07:00
Phil Wang
349aaca56f add yet another transformer stability measure 2022-07-12 17:49:16 -07:00
Phil Wang
3ee3c56d2a add learned padding tokens, same strategy as dalle1, for diffusion prior, and get rid of masking in causal transformer 2022-07-12 17:33:14 -07:00
Phil Wang
cd26c6b17d 0.22.3 2022-07-12 17:08:31 -07:00
Phil Wang
775abc4df6 add setting to attend to all text encodings regardless of padding, for diffusion prior 2022-07-12 17:08:12 -07:00
5 changed files with 89 additions and 24 deletions

View File

@@ -527,25 +527,31 @@ class NoiseScheduler(nn.Module):
# diffusion prior
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
def __init__(self, dim, eps = 1e-5, stable = False):
super().__init__()
self.eps = eps
self.stable = stable
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
x = x / x.amax(dim = -1, keepdim = True).detach()
if self.stable:
x = x / x.amax(dim = -1, keepdim = True).detach()
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = -1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
def __init__(self, dim, eps = 1e-5, stable = False):
super().__init__()
self.eps = eps
self.stable = stable
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
x = x / x.amax(dim = 1, keepdim = True).detach()
if self.stable:
x = x / x.amax(dim = 1, keepdim = True).detach()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g
@@ -669,7 +675,7 @@ class Attention(nn.Module):
dropout = 0.,
causal = False,
rotary_emb = None,
pb_relax_alpha = 32 ** 2
pb_relax_alpha = 128
):
super().__init__()
self.pb_relax_alpha = pb_relax_alpha
@@ -760,6 +766,7 @@ class CausalTransformer(nn.Module):
dim_head = 64,
heads = 8,
ff_mult = 4,
norm_in = False,
norm_out = True,
attn_dropout = 0.,
ff_dropout = 0.,
@@ -768,6 +775,8 @@ class CausalTransformer(nn.Module):
rotary_emb = True
):
super().__init__()
self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
self.rel_pos_bias = RelPosBias(heads = heads)
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
@@ -779,20 +788,18 @@ class CausalTransformer(nn.Module):
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
def forward(
self,
x,
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
):
def forward(self, x):
n, device = x.shape[1], x.device
x = self.init_norm(x)
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
for attn, ff in self.layers:
x = attn(x, mask = mask, attn_bias = attn_bias) + x
x = attn(x, attn_bias = attn_bias) + x
x = ff(x) + x
out = self.norm(x)
@@ -806,6 +813,7 @@ class DiffusionPriorNetwork(nn.Module):
num_time_embeds = 1,
num_image_embeds = 1,
num_text_embeds = 1,
max_text_len = 256,
**kwargs
):
super().__init__()
@@ -831,6 +839,11 @@ class DiffusionPriorNetwork(nn.Module):
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
# dalle1 learned padding strategy
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
def forward_with_cond_scale(
self,
*args,
@@ -852,7 +865,6 @@ class DiffusionPriorNetwork(nn.Module):
*,
text_embed,
text_encodings = None,
mask = None,
cond_drop_prob = 0.
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
@@ -870,9 +882,29 @@ class DiffusionPriorNetwork(nn.Module):
if not exists(text_encodings):
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
mask = torch.any(text_encodings != 0., dim = -1)
# replace any padding in the text encodings with learned padding tokens unique across position
text_encodings = text_encodings[:, :self.max_text_len]
mask = mask[:, :self.max_text_len]
text_len = text_encodings.shape[-2]
remainder = self.max_text_len - text_len
if remainder > 0:
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
mask = F.pad(mask, (0, remainder), value = False)
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
text_encodings = torch.where(
rearrange(mask, 'b n -> b n 1'),
text_encodings,
null_text_embeds
)
# classifier free guidance
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
@@ -905,7 +937,7 @@ class DiffusionPriorNetwork(nn.Module):
# attend
tokens = self.causal_transformer(tokens, mask = mask)
tokens = self.causal_transformer(tokens)
# get learned query, which should predict the image embedding (per DDPM timestep)
@@ -1219,6 +1251,14 @@ class DiffusionPrior(nn.Module):
# decoder
def NearestUpsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
nn.Conv2d(dim, dim_out, 3, padding = 1)
)
class PixelShuffleUpsample(nn.Module):
"""
code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
@@ -1625,7 +1665,7 @@ class Unet(nn.Module):
# upsample klass
upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
# give memory efficient unet an initial resnet block
@@ -1914,6 +1954,7 @@ class LowresConditioner(nn.Module):
self,
downsample_first = True,
downsample_mode_nearest = False,
blur_prob = 0.5,
blur_sigma = 0.6,
blur_kernel_size = 3,
input_image_range = None
@@ -1924,6 +1965,7 @@ class LowresConditioner(nn.Module):
self.input_image_range = input_image_range
self.blur_prob = blur_prob
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
@@ -1936,20 +1978,27 @@ class LowresConditioner(nn.Module):
blur_sigma = None,
blur_kernel_size = None
):
if self.training and self.downsample_first and exists(downsample_image_size):
if self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)
if self.training:
# blur is only applied 50% of the time
# section 3.1 in https://arxiv.org/abs/2106.15282
if random.random() < self.blur_prob:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
# allow for drawing a random sigma between lo and hi float values
if isinstance(blur_sigma, tuple):
blur_sigma = tuple(map(float, blur_sigma))
blur_sigma = random.uniform(*blur_sigma)
# allow for drawing a random kernel size between lo and hi int values
if isinstance(blur_kernel_size, tuple):
blur_kernel_size = tuple(map(int, blur_kernel_size))
kernel_size_lo, kernel_size_hi = blur_kernel_size
@@ -1958,7 +2007,6 @@ class LowresConditioner(nn.Module):
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range)
return cond_fmap
class Decoder(nn.Module):
@@ -1982,6 +2030,7 @@ class Decoder(nn.Module):
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution
blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
blur_sigma = 0.6, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
clip_denoised = True,
@@ -2130,9 +2179,12 @@ class Decoder(nn.Module):
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.lowres_downsample_mode_nearest = lowres_downsample_mode_nearest
self.to_lowres_cond = LowresConditioner(
downsample_first = lowres_downsample_first,
downsample_mode_nearest = lowres_downsample_mode_nearest,
blur_prob = blur_prob,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
input_image_range = self.input_image_range
@@ -2296,6 +2348,9 @@ class Decoder(nn.Module):
img = torch.randn(shape, device = device)
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = alphas[time]
alpha_next = alphas[time_next]
@@ -2448,7 +2503,7 @@ class Decoder(nn.Module):
shape = (batch_size, channel, image_size, image_size)
if unet.lowres_cond:
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = self.lowres_downsample_mode_nearest)
is_latent_diffusion = isinstance(vae, VQGanVAE)
image_size = vae.get_encoded_fmap_size(image_size)

View File

@@ -129,6 +129,7 @@ class AdapterConfig(BaseModel):
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
max_text_len: int = None
num_timesteps: int = None
num_time_embeds: int = 1
num_image_embeds: int = 1
@@ -136,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
dim_head: int = 64
heads: int = 8
ff_mult: int = 4
norm_in: bool = False
norm_out: bool = True
attn_dropout: float = 0.
ff_dropout: float = 0.

View File

@@ -673,8 +673,14 @@ class DecoderTrainer(nn.Module):
def sample(self, *args, **kwargs):
distributed = self.accelerator.num_processes > 1
base_decoder = self.accelerator.unwrap_model(self.decoder)
was_training = base_decoder.training
base_decoder.eval()
if kwargs.pop('use_non_ema', False) or not self.use_ema:
return base_decoder.sample(*args, **kwargs, distributed = distributed)
out = base_decoder.sample(*args, **kwargs, distributed = distributed)
base_decoder.train(was_training)
return out
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
@@ -687,6 +693,7 @@ class DecoderTrainer(nn.Module):
for ema in self.ema_unets:
ema.restore_ema_model_device()
base_decoder.train(was_training)
return output
@torch.no_grad()

View File

@@ -1 +1 @@
__version__ = '0.22.1'
__version__ = '0.23.7'

View File

@@ -323,7 +323,7 @@ def train(
last_snapshot = sample
if next_task == 'train':
for i, (img, emb, txt) in enumerate(trainer.train_loader):
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -358,6 +358,7 @@ def train(
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
forward_params['text'] = tokenized_texts
loss = trainer.forward(img, **forward_params, unet_number=unet)
trainer.update(unet_number=unet)
@@ -416,7 +417,7 @@ def train(
timer = Timer()
accelerator.wait_for_everyone()
i = 0
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader
val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()