mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 20:34:22 +01:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce4b0107c1 | ||
|
|
64c2f9c4eb | ||
|
|
22cc613278 | ||
|
|
83517849e5 | ||
|
|
708809ed6c | ||
|
|
9cc475f6e7 | ||
|
|
ffd342e9d0 | ||
|
|
f8bfd3493a | ||
|
|
9025345e29 | ||
|
|
8cc278447e |
@@ -1704,10 +1704,12 @@ class LowresConditioner(nn.Module):
|
|||||||
|
|
||||||
# allow for drawing a random sigma between lo and hi float values
|
# allow for drawing a random sigma between lo and hi float values
|
||||||
if isinstance(blur_sigma, tuple):
|
if isinstance(blur_sigma, tuple):
|
||||||
|
blur_sigma = tuple(map(float, blur_sigma))
|
||||||
blur_sigma = random.uniform(*blur_sigma)
|
blur_sigma = random.uniform(*blur_sigma)
|
||||||
|
|
||||||
# allow for drawing a random kernel size between lo and hi int values
|
# allow for drawing a random kernel size between lo and hi int values
|
||||||
if isinstance(blur_kernel_size, tuple):
|
if isinstance(blur_kernel_size, tuple):
|
||||||
|
blur_kernel_size = tuple(map(int, blur_kernel_size))
|
||||||
kernel_size_lo, kernel_size_hi = blur_kernel_size
|
kernel_size_lo, kernel_size_hi = blur_kernel_size
|
||||||
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
|
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
|
||||||
|
|
||||||
@@ -1743,6 +1745,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
clip_x_start = True,
|
clip_x_start = True,
|
||||||
clip_adapter_overrides = dict(),
|
clip_adapter_overrides = dict(),
|
||||||
learned_variance = True,
|
learned_variance = True,
|
||||||
|
learned_variance_constrain_frac = False,
|
||||||
vb_loss_weight = 0.001,
|
vb_loss_weight = 0.001,
|
||||||
unconditional = False,
|
unconditional = False,
|
||||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||||
@@ -1803,6 +1806,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
||||||
self.learned_variance = learned_variance
|
self.learned_variance = learned_variance
|
||||||
|
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
|
||||||
self.vb_loss_weight = vb_loss_weight
|
self.vb_loss_weight = vb_loss_weight
|
||||||
|
|
||||||
# construct unets and vaes
|
# construct unets and vaes
|
||||||
@@ -1943,6 +1947,9 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
max_log = extract(torch.log(self.betas), t, x.shape)
|
max_log = extract(torch.log(self.betas), t, x.shape)
|
||||||
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
||||||
|
|
||||||
|
if self.learned_variance_constrain_frac:
|
||||||
|
var_interp_frac = var_interp_frac.sigmoid()
|
||||||
|
|
||||||
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
||||||
posterior_variance = posterior_log_variance.exp()
|
posterior_variance = posterior_log_variance.exp()
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ def get_optimizer(
|
|||||||
params,
|
params,
|
||||||
lr = 1e-4,
|
lr = 1e-4,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
betas = (0.9, 0.999),
|
betas = (0.9, 0.99),
|
||||||
eps = 1e-8,
|
eps = 1e-8,
|
||||||
filter_by_requires_grad = False,
|
filter_by_requires_grad = False,
|
||||||
group_wd_params = True,
|
group_wd_params = True,
|
||||||
|
|||||||
@@ -175,12 +175,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
|
|||||||
# exponential moving average wrapper
|
# exponential moving average wrapper
|
||||||
|
|
||||||
class EMA(nn.Module):
|
class EMA(nn.Module):
|
||||||
|
"""
|
||||||
|
Implements exponential moving average shadowing for your model.
|
||||||
|
|
||||||
|
Utilizes an inverse decay schedule to manage longer term training runs.
|
||||||
|
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
|
||||||
|
|
||||||
|
@crowsonkb's notes on EMA Warmup:
|
||||||
|
|
||||||
|
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
|
||||||
|
good values for models you plan to train for a million or more steps (reaches decay
|
||||||
|
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
|
||||||
|
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
|
||||||
|
215.4k steps).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||||
|
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
||||||
|
min_value (float): The minimum EMA decay rate. Default: 0.
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
beta = 0.99,
|
beta = 0.9999,
|
||||||
update_after_step = 1000,
|
update_after_step = 10000,
|
||||||
update_every = 10,
|
update_every = 10,
|
||||||
|
inv_gamma = 1.0,
|
||||||
|
power = 2/3,
|
||||||
|
min_value = 0.0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
@@ -188,7 +210,11 @@ class EMA(nn.Module):
|
|||||||
self.ema_model = copy.deepcopy(model)
|
self.ema_model = copy.deepcopy(model)
|
||||||
|
|
||||||
self.update_every = update_every
|
self.update_every = update_every
|
||||||
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
|
self.update_after_step = update_after_step
|
||||||
|
|
||||||
|
self.inv_gamma = inv_gamma
|
||||||
|
self.power = power
|
||||||
|
self.min_value = min_value
|
||||||
|
|
||||||
self.register_buffer('initted', torch.Tensor([False]))
|
self.register_buffer('initted', torch.Tensor([False]))
|
||||||
self.register_buffer('step', torch.tensor([0]))
|
self.register_buffer('step', torch.tensor([0]))
|
||||||
@@ -198,37 +224,44 @@ class EMA(nn.Module):
|
|||||||
self.ema_model.to(device)
|
self.ema_model.to(device)
|
||||||
|
|
||||||
def copy_params_from_model_to_ema(self):
|
def copy_params_from_model_to_ema(self):
|
||||||
self.ema_model.state_dict(self.online_model.state_dict())
|
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
||||||
|
ma_param.data.copy_(current_param.data)
|
||||||
|
|
||||||
|
def get_current_decay(self):
|
||||||
|
epoch = max(0, self.step.item() - self.update_after_step - 1)
|
||||||
|
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
||||||
|
return 0. if epoch < 0 else min(self.beta, max(self.min_value, value))
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
step = self.step.item()
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
|
||||||
if (self.step % self.update_every) != 0:
|
if (step % self.update_every) != 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.step <= self.update_after_step:
|
if step <= self.update_after_step:
|
||||||
self.copy_params_from_model_to_ema()
|
self.copy_params_from_model_to_ema()
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.initted:
|
if not self.initted.item():
|
||||||
self.copy_params_from_model_to_ema()
|
self.copy_params_from_model_to_ema()
|
||||||
self.initted.data.copy_(torch.Tensor([True]))
|
self.initted.data.copy_(torch.Tensor([True]))
|
||||||
|
|
||||||
self.update_moving_average(self.ema_model, self.online_model)
|
self.update_moving_average(self.ema_model, self.online_model)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def update_moving_average(self, ma_model, current_model):
|
def update_moving_average(self, ma_model, current_model):
|
||||||
def calculate_ema(beta, old, new):
|
current_decay = self.get_current_decay()
|
||||||
if not exists(old):
|
|
||||||
return new
|
|
||||||
return old * beta + (1 - beta) * new
|
|
||||||
|
|
||||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
|
||||||
old_weight, up_weight = ma_params.data, current_params.data
|
difference = ma_params.data - current_params.data
|
||||||
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
difference.mul_(1.0 - current_decay)
|
||||||
|
ma_params.sub_(difference)
|
||||||
|
|
||||||
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
|
||||||
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
difference = ma_buffer - current_buffer
|
||||||
ma_buffer.copy_(new_buffer_value)
|
difference.mul_(1.0 - current_decay)
|
||||||
|
ma_buffer.sub_(difference)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.ema_model(*args, **kwargs)
|
return self.ema_model(*args, **kwargs)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.6.7'
|
__version__ = '0.6.13'
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
|||||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||||
|
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
||||||
|
|
||||||
import torchvision
|
import torchvision
|
||||||
import torch
|
import torch
|
||||||
@@ -136,6 +137,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
|
|||||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||||
"""
|
"""
|
||||||
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
||||||
|
|
||||||
|
real_image_size = real_images[0].shape[-1]
|
||||||
|
generated_image_size = generated_images[0].shape[-1]
|
||||||
|
|
||||||
|
# training images may be larger than the generated ones
|
||||||
|
if real_image_size > generated_image_size:
|
||||||
|
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
|
||||||
|
|
||||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||||
return grid_images, captions
|
return grid_images, captions
|
||||||
|
|
||||||
@@ -322,7 +331,7 @@ def train(
|
|||||||
sample = 0
|
sample = 0
|
||||||
average_loss = 0
|
average_loss = 0
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
for i, (img, emb, *_) in enumerate(dataloaders["val"]):
|
||||||
sample += img.shape[0]
|
sample += img.shape[0]
|
||||||
img, emb = send_to_device((img, emb))
|
img, emb = send_to_device((img, emb))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user