Allow random blur sigma and blur kernel size augmentations on the low-resolution conditioning image. (TODO: reread the paper to check whether the sampled augmentation value also needs to be fed into the unet as conditioning.)

This commit is contained in:
Phil Wang
2022-06-02 11:11:18 -07:00
parent 1cc288af39
commit 38cd62010c
2 changed files with 14 additions and 3 deletions

View File

@@ -1,4 +1,5 @@
import math
import random
from tqdm import tqdm
from inspect import isfunction
from functools import partial, wraps
@@ -1676,7 +1677,7 @@ class LowresConditioner(nn.Module):
def __init__(
self,
downsample_first = True,
blur_sigma = 0.1,
blur_sigma = (0.1, 0.2),
blur_kernel_size = 3,
):
super().__init__()
@@ -1700,6 +1701,16 @@ class LowresConditioner(nn.Module):
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
# allow for drawing a random sigma between lo and hi float values
if isinstance(blur_sigma, tuple):
blur_sigma = random.uniform(*blur_sigma)
# allow for drawing a random kernel size between lo and hi int values
if isinstance(blur_kernel_size, tuple):
kernel_size_lo, kernel_size_hi = blur_kernel_size
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size)
@@ -1725,7 +1736,7 @@ class Decoder(BaseGaussianDiffusion):
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True,

View File

@@ -1 +1 @@
__version__ = '0.6.6'
__version__ = '0.6.7'