Compare commits

..

4 Commits

3 changed files with 18 additions and 2 deletions

View File

@@ -1704,10 +1704,12 @@ class LowresConditioner(nn.Module):
# allow for drawing a random sigma between lo and hi float values
if isinstance(blur_sigma, tuple):
blur_sigma = tuple(map(float, blur_sigma))
blur_sigma = random.uniform(*blur_sigma)
# allow for drawing a random kernel size between lo and hi int values
if isinstance(blur_kernel_size, tuple):
blur_kernel_size = tuple(map(int, blur_kernel_size))
kernel_size_lo, kernel_size_hi = blur_kernel_size
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
@@ -1743,6 +1745,7 @@ class Decoder(BaseGaussianDiffusion):
clip_x_start = True,
clip_adapter_overrides = dict(),
learned_variance = True,
learned_variance_constrain_frac = False,
vb_loss_weight = 0.001,
unconditional = False,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
@@ -1803,6 +1806,7 @@ class Decoder(BaseGaussianDiffusion):
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
self.learned_variance = learned_variance
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
self.vb_loss_weight = vb_loss_weight
# construct unets and vaes
@@ -1943,6 +1947,9 @@ class Decoder(BaseGaussianDiffusion):
max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
if self.learned_variance_constrain_frac:
var_interp_frac = var_interp_frac.sigmoid()
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp()

View File

@@ -1 +1 @@
__version__ = '0.6.7'
__version__ = '0.6.9'

View File

@@ -4,6 +4,7 @@ from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
import torchvision
import torch
@@ -136,6 +137,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
# training images may be larger than the generated one
if real_image_size > generated_image_size:
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
@@ -322,7 +331,7 @@ def train(
sample = 0
average_loss = 0
timer = Timer()
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
for i, (img, emb, *_) in enumerate(dataloaders["val"]):
sample += img.shape[0]
img, emb = send_to_device((img, emb))