Compare commits


4 Commits

Author SHA1 Message Date
Phil Wang
95a512cb65 fix a potential bug with conditioning with blurred low resolution image, blur should be applied only 50% of the time 2022-07-13 10:11:49 -07:00
Phil Wang
972ee973bc fix issue with ddim and normalization of lowres conditioning image 2022-07-13 09:48:40 -07:00
Phil Wang
79e2a3bc77 only use the stable layernorm for final output norm in transformer 2022-07-13 07:56:30 -07:00
Aidan Dempster
544cdd0b29 Reverted to using basic dataloaders (#205): Accelerate removes the ability to collate strings, likely because it cannot gather strings. 2022-07-12 18:22:27 -07:00
3 changed files with 35 additions and 13 deletions


@@ -527,25 +527,31 @@ class NoiseScheduler(nn.Module):
 # diffusion prior

 class LayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps = 1e-5, stable = False):
         super().__init__()
         self.eps = eps
+        self.stable = stable
         self.g = nn.Parameter(torch.ones(dim))

     def forward(self, x):
-        x = x / x.amax(dim = -1, keepdim = True).detach()
+        if self.stable:
+            x = x / x.amax(dim = -1, keepdim = True).detach()
+
         var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = -1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g

 class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps = 1e-5, stable = False):
         super().__init__()
         self.eps = eps
+        self.stable = stable
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

     def forward(self, x):
-        x = x / x.amax(dim = 1, keepdim = True).detach()
+        if self.stable:
+            x = x / x.amax(dim = 1, keepdim = True).detach()
+
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g
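The stable flag reintroduces a numerical-stability trick as an opt-in: because layer norm is invariant to a positive rescaling of its input (up to the eps term), dividing by the detached row-wise maximum first leaves the output essentially unchanged while keeping the variance computation from overflowing at low precision. A minimal standalone sketch of the stable path, not the library's exact code:

    import torch

    def stable_layernorm(x, g, eps = 1e-5):
        # rescale by the detached row-wise max (assumed positive) so that
        # the variance below is computed on O(1) values and cannot overflow
        x = x / x.amax(dim = -1, keepdim = True).detach()
        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)
        return (x - mean) * (var + eps).rsqrt() * g

    x = torch.randn(2, 8) * 1e20               # extreme activations
    print(torch.var(x, dim = -1))              # inf: the squared deviations overflow fp32
    print(stable_layernorm(x, torch.ones(8)))  # finite, thanks to the pre-rescaling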
@@ -669,7 +675,7 @@ class Attention(nn.Module):
         dropout = 0.,
         causal = False,
         rotary_emb = None,
-        pb_relax_alpha = 32 ** 2
+        pb_relax_alpha = 128
     ):
         super().__init__()
         self.pb_relax_alpha = pb_relax_alpha
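The pb_relax_alpha default drops from 32 ** 2 = 1024 to 128. PB-relax is the attention-overflow fix from CogView (https://arxiv.org/abs/2105.13290): form the attention logits at 1/alpha of their true magnitude, subtract the detached row max, then scale back up before the softmax, which leaves the softmax output unchanged while bounding every intermediate. A sketch of the pattern; its exact placement inside this library's Attention.forward may differ:

    import torch

    def pb_relax_scores(q, k, dim_head = 64, alpha = 128):
        scale = dim_head ** -0.5
        # logits computed at 1 / alpha of their true size, so the matmul stays small
        sim = torch.einsum('b h i d, b h j d -> b h i j', q * (scale / alpha), k)
        # subtracting the detached max and rescaling leaves softmax(alpha * sim) intact
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        return (sim * alpha).softmax(dim = -1)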
@@ -782,7 +788,7 @@ class CausalTransformer(nn.Module):
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))

-        self.norm = LayerNorm(dim) if norm_out else nn.Identity()  # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
+        self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity()  # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
         self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()

     def forward(self, x):
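Per the commit message, only this final output norm opts into the stable variant; the per-block pre-norms keep stable = False. That is safe because the stable rescaling changes the output only through the eps term, as a quick standalone check suggests (assuming positive row maxima):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 512).abs() + 1   # positive activations, so amax > 0
    plain = F.layer_norm(x, (512,))
    stable = F.layer_norm(x / x.amax(dim = -1, keepdim = True).detach(), (512,))
    print(torch.allclose(plain, stable, atol = 1e-2))   # True: same output up to eps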
@@ -1940,6 +1946,7 @@ class LowresConditioner(nn.Module):
         self,
         downsample_first = True,
         downsample_mode_nearest = False,
+        blur_prob = 0.5,
         blur_sigma = 0.6,
         blur_kernel_size = 3,
         input_image_range = None
@@ -1950,6 +1957,7 @@ class LowresConditioner(nn.Module):
         self.input_image_range = input_image_range
+        self.blur_prob = blur_prob
         self.blur_sigma = blur_sigma
         self.blur_kernel_size = blur_kernel_size
@@ -1962,20 +1970,27 @@ class LowresConditioner(nn.Module):
         blur_sigma = None,
         blur_kernel_size = None
     ):
-        if self.training and self.downsample_first and exists(downsample_image_size):
+        if self.downsample_first and exists(downsample_image_size):
             cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)

+        if self.training:
+            # blur is only applied 50% of the time
+            # section 3.1 in https://arxiv.org/abs/2106.15282
+
+            if random.random() < self.blur_prob:
                 # when training, blur the low resolution conditional image
                 blur_sigma = default(blur_sigma, self.blur_sigma)
                 blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)

                 # allow for drawing a random sigma between lo and hi float values
                 if isinstance(blur_sigma, tuple):
                     blur_sigma = tuple(map(float, blur_sigma))
                     blur_sigma = random.uniform(*blur_sigma)

                 # allow for drawing a random kernel size between lo and hi int values
                 if isinstance(blur_kernel_size, tuple):
                     blur_kernel_size = tuple(map(int, blur_kernel_size))
                     kernel_size_lo, kernel_size_hi = blur_kernel_size
@@ -1984,7 +1999,6 @@ class LowresConditioner(nn.Module):
                 cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))

         cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range)

         return cond_fmap

 class Decoder(nn.Module):
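Net effect of the LowresConditioner changes: downsampling now also happens at evaluation time, while the Gaussian blur augmentation is applied only with probability blur_prob during training, following section 3.1 of the cascaded diffusion paper (https://arxiv.org/abs/2106.15282). Condensed into a standalone sketch around the same kornia helper the code uses:

    import random
    import torch
    from kornia.filters import gaussian_blur2d

    def maybe_blur(cond_fmap, blur_prob = 0.5, blur_sigma = 0.6, blur_kernel_size = 3):
        # blur the low resolution conditioning image only blur_prob of the time
        if random.random() < blur_prob:
            cond_fmap = gaussian_blur2d(cond_fmap, (blur_kernel_size,) * 2, (blur_sigma,) * 2)
        return cond_fmap

    lowres_cond = maybe_blur(torch.rand(1, 3, 64, 64))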
@@ -2008,6 +2022,7 @@ class Decoder(nn.Module):
         random_crop_sizes = None,                  # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
         lowres_downsample_first = True,            # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
         lowres_downsample_mode_nearest = False,    # cascading ddpm - whether to use nearest mode downsampling for lower resolution
+        blur_prob = 0.5,                           # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
         blur_sigma = 0.6,                          # cascading ddpm - blur sigma
         blur_kernel_size = 3,                      # cascading ddpm - blur kernel size
         clip_denoised = True,
@@ -2156,9 +2171,12 @@ class Decoder(nn.Module):
         lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
         assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'

+        self.lowres_downsample_mode_nearest = lowres_downsample_mode_nearest
+
         self.to_lowres_cond = LowresConditioner(
             downsample_first = lowres_downsample_first,
             downsample_mode_nearest = lowres_downsample_mode_nearest,
+            blur_prob = blur_prob,
             blur_sigma = blur_sigma,
             blur_kernel_size = blur_kernel_size,
             input_image_range = self.input_image_range
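From the user's side the knob is just a new Decoder constructor argument. A hypothetical configuration showing where it lands (the unets and other required arguments are placeholders, not from this diff):

    decoder = Decoder(
        unet = (unet1, unet2),   # first unet unconditioned, second lowres-conditioned
        image_sizes = (64, 256),
        blur_prob = 0.5,         # blur the lowres conditioning image half the time in training
        blur_sigma = 0.6,
        blur_kernel_size = 3
    )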
@@ -2322,6 +2340,9 @@ class Decoder(nn.Module):
         img = torch.randn(shape, device = device)

+        if not is_latent_diffusion:
+            lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
+
         for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
             alpha = alphas[time]
             alpha_next = alphas[time_next]
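This hunk is the DDIM fix from commit 972ee973bc: the low resolution conditioning image arrives in [0, 1], but outside of latent diffusion the unet works in [-1, 1], so it must be normalized before the sampling loop. self.normalize_img is presumably the usual helper pair in this codebase; sketched:

    def normalize_neg_one_to_one(img):
        # [0, 1] -> [-1, 1], the range the unet is trained on
        return img * 2 - 1

    def unnormalize_zero_to_one(normed_img):
        # [-1, 1] -> [0, 1], for returning images to the caller
        return (normed_img + 1) * 0.5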
@@ -2474,7 +2495,7 @@ class Decoder(nn.Module):
             shape = (batch_size, channel, image_size, image_size)

             if unet.lowres_cond:
-                lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
+                lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = self.lowres_downsample_mode_nearest)

             is_latent_diffusion = isinstance(vae, VQGanVAE)
             image_size = vae.get_encoded_fmap_size(image_size)
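At sample time, the previous unet's output is now fed to the next stage with a plain resize rather than through self.to_lowres_cond (whose blur is training-only after this change anyway). A minimal sketch of what a resize_image_to with this signature could look like; the library's actual helper may differ:

    import torch.nn.functional as F

    def resize_image_to(image, target_image_size, clamp_range = None, nearest = False):
        if image.shape[-1] == target_image_size:
            return image
        if nearest:
            out = F.interpolate(image, target_image_size, mode = 'nearest')
        else:
            out = F.interpolate(image, target_image_size, mode = 'bilinear', align_corners = False)
        if clamp_range is not None:
            out = out.clamp(*clamp_range)   # e.g. keep pixel values in the input image range
        return out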


@@ -1 +1 @@
-__version__ = '0.23.2'
+__version__ = '0.23.5'


@@ -323,7 +323,7 @@ def train(
                 last_snapshot = sample

         if next_task == 'train':
-            for i, (img, emb, txt) in enumerate(trainer.train_loader):
+            for i, (img, emb, txt) in enumerate(dataloaders["train"]):
                 # We want to count the total number of samples across all processes
                 sample_length_tensor[0] = len(img)
                 all_samples = accelerator.gather(sample_length_tensor)  # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
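Context for the revert (#205): dataloaders wrapped by accelerator.prepare dropped the caption strings from each batch, likely because strings cannot be gathered across processes, so the script goes back to plain torch dataloaders kept in a dict. A sketch of the assumed setup; the dataset objects and batch size are placeholders:

    from torch.utils.data import DataLoader

    # deliberately NOT passed through accelerator.prepare, so that string
    # fields such as captions survive collation in each batch
    dataloaders = {
        'train': DataLoader(train_dataset, batch_size = 64, shuffle = True),
        'val': DataLoader(val_dataset, batch_size = 64)
    }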
@@ -358,6 +358,7 @@ def train(
                 else:
                     # Then we need to pass the text instead
                     tokenized_texts = tokenize(txt, truncate=True)
+                    assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
                     forward_params['text'] = tokenized_texts
                 loss = trainer.forward(img, **forward_params, unet_number=unet)
                 trainer.update(unet_number=unet)
@@ -416,7 +417,7 @@ def train(
             timer = Timer()
             accelerator.wait_for_everyone()
             i = 0
-            for i, (img, emb, txt) in enumerate(trainer.val_loader):  # Use the accelerate prepared loader
+            for i, (img, emb, txt) in enumerate(dataloaders['val']):  # use the raw (unprepared) loader so string fields survive
                 val_sample_length_tensor[0] = len(img)
                 all_samples = accelerator.gather(val_sample_length_tensor)
                 total_samples = all_samples.sum().item()
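Both loops then count samples across processes with the same gather-and-sum pattern (used because accelerator.reduce was broken when this was written, per the TODO above). Condensed from the code:

    import torch

    sample_length_tensor = torch.zeros(1, dtype = torch.long, device = accelerator.device)
    sample_length_tensor[0] = len(img)                      # this process's batch size
    all_samples = accelerator.gather(sample_length_tensor)  # one entry per process
    total_samples = all_samples.sum().item()                # global sample count this step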