Compare commits

...

7 Commits

5 changed files with 34 additions and 20 deletions

View File

@@ -1745,6 +1745,7 @@ class Decoder(BaseGaussianDiffusion):
clip_x_start = True,
clip_adapter_overrides = dict(),
learned_variance = True,
learned_variance_constrain_frac = False,
vb_loss_weight = 0.001,
unconditional = False,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
@@ -1805,6 +1806,7 @@ class Decoder(BaseGaussianDiffusion):
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
self.learned_variance = learned_variance
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
self.vb_loss_weight = vb_loss_weight
# construct unets and vaes
@@ -1945,6 +1947,9 @@ class Decoder(BaseGaussianDiffusion):
max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
if self.learned_variance_constrain_frac:
var_interp_frac = var_interp_frac.sigmoid()
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp()

View File

@@ -11,7 +11,7 @@ def get_optimizer(
params,
lr = 1e-4,
wd = 1e-2,
betas = (0.9, 0.999),
betas = (0.9, 0.99),
eps = 1e-8,
filter_by_requires_grad = False,
group_wd_params = True,

View File

@@ -178,7 +178,7 @@ class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.99,
beta = 0.9999,
update_after_step = 1000,
update_every = 10,
):
@@ -188,7 +188,7 @@ class EMA(nn.Module):
self.ema_model = copy.deepcopy(model)
self.update_every = update_every
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
self.update_after_step = update_after_step
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0]))
@@ -198,37 +198,37 @@ class EMA(nn.Module):
self.ema_model.to(device)
def copy_params_from_model_to_ema(self):
self.ema_model.state_dict(self.online_model.state_dict())
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
ma_param.data.copy_(current_param.data)
def update(self):
step = self.step.item()
self.step += 1
if (self.step % self.update_every) != 0:
if (step % self.update_every) != 0:
return
if self.step <= self.update_after_step:
if step <= self.update_after_step:
self.copy_params_from_model_to_ema()
return
if not self.initted:
if not self.initted.item():
self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
@torch.no_grad()
def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
return old * beta + (1 - beta) * new
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
difference = ma_params.data - current_params.data
difference.mul_(1.0 - self.beta)
ma_params.sub_(difference)
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
difference = ma_buffer - current_buffer
difference.mul_(1.0 - self.beta)
ma_buffer.sub_(difference)
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)

View File

@@ -1 +1 @@
__version__ = '0.6.8'
__version__ = '0.6.12'

View File

@@ -4,6 +4,7 @@ from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
import torchvision
import torch
@@ -136,6 +137,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
# training images may be larger than the generated one
if real_image_size > generated_image_size:
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
@@ -322,7 +331,7 @@ def train(
sample = 0
average_loss = 0
timer = Timer()
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
for i, (img, emb, *_) in enumerate(dataloaders["val"]):
sample += img.shape[0]
img, emb = send_to_device((img, emb))