mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
add gradient checkpointing for all resnet blocks
This commit is contained in:
@@ -8,6 +8,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
from torch import nn, einsum
|
from torch import nn, einsum
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
|
|
||||||
@@ -108,6 +109,28 @@ def pad_tuple_to_length(t, length, fillvalue = None):
|
|||||||
return t
|
return t
|
||||||
return (*t, *((fillvalue,) * remain_length))
|
return (*t, *((fillvalue,) * remain_length))
|
||||||
|
|
||||||
|
# checkpointing helper function
|
||||||
|
|
||||||
|
def make_checkpointable(fn, **kwargs):
|
||||||
|
if isinstance(fn, nn.ModuleList):
|
||||||
|
return [maybe(make_checkpointable)(el, **kwargs) for el in fn]
|
||||||
|
|
||||||
|
condition = kwargs.pop('condition', None)
|
||||||
|
|
||||||
|
if exists(condition) and not condition(fn):
|
||||||
|
return fn
|
||||||
|
|
||||||
|
@wraps(fn)
|
||||||
|
def inner(*args):
|
||||||
|
input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])
|
||||||
|
|
||||||
|
if not input_needs_grad:
|
||||||
|
return fn(*args)
|
||||||
|
|
||||||
|
return checkpoint(fn, *args)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
# for controlling freezing of CLIP
|
# for controlling freezing of CLIP
|
||||||
|
|
||||||
def set_module_requires_grad_(module, requires_grad):
|
def set_module_requires_grad_(module, requires_grad):
|
||||||
@@ -1698,6 +1721,7 @@ class Unet(nn.Module):
|
|||||||
pixel_shuffle_upsample = True,
|
pixel_shuffle_upsample = True,
|
||||||
final_conv_kernel_size = 1,
|
final_conv_kernel_size = 1,
|
||||||
combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
|
combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
|
||||||
|
checkpoint_during_training = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1908,6 +1932,10 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
|
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
|
||||||
|
|
||||||
|
# whether to checkpoint during training
|
||||||
|
|
||||||
|
self.checkpoint_during_training = checkpoint_during_training
|
||||||
|
|
||||||
# if the current settings for the unet are not correct
|
# if the current settings for the unet are not correct
|
||||||
# for cascading DDPM, then reinit the unet with the right settings
|
# for cascading DDPM, then reinit the unet with the right settings
|
||||||
def cast_model_parameters(
|
def cast_model_parameters(
|
||||||
@@ -1965,7 +1993,8 @@ class Unet(nn.Module):
|
|||||||
image_cond_drop_prob = 0.,
|
image_cond_drop_prob = 0.,
|
||||||
text_cond_drop_prob = 0.,
|
text_cond_drop_prob = 0.,
|
||||||
blur_sigma = None,
|
blur_sigma = None,
|
||||||
blur_kernel_size = None
|
blur_kernel_size = None,
|
||||||
|
disable_checkpoint = False
|
||||||
):
|
):
|
||||||
batch_size, device = x.shape[0], x.device
|
batch_size, device = x.shape[0], x.device
|
||||||
|
|
||||||
@@ -2087,17 +2116,29 @@ class Unet(nn.Module):
|
|||||||
c = self.norm_cond(c)
|
c = self.norm_cond(c)
|
||||||
mid_c = self.norm_mid_cond(mid_c)
|
mid_c = self.norm_mid_cond(mid_c)
|
||||||
|
|
||||||
|
# gradient checkpointing
|
||||||
|
|
||||||
|
can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint
|
||||||
|
apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity
|
||||||
|
|
||||||
|
# make checkpointable modules
|
||||||
|
|
||||||
|
init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]
|
||||||
|
|
||||||
|
can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)
|
||||||
|
downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]
|
||||||
|
|
||||||
# initial resnet block
|
# initial resnet block
|
||||||
|
|
||||||
if exists(self.init_resnet_block):
|
if exists(init_resnet_block):
|
||||||
x = self.init_resnet_block(x, t)
|
x = init_resnet_block(x, t)
|
||||||
|
|
||||||
# go through the layers of the unet, down and up
|
# go through the layers of the unet, down and up
|
||||||
|
|
||||||
down_hiddens = []
|
down_hiddens = []
|
||||||
up_hiddens = []
|
up_hiddens = []
|
||||||
|
|
||||||
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
|
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:
|
||||||
if exists(pre_downsample):
|
if exists(pre_downsample):
|
||||||
x = pre_downsample(x)
|
x = pre_downsample(x)
|
||||||
|
|
||||||
@@ -2113,16 +2154,16 @@ class Unet(nn.Module):
|
|||||||
if exists(post_downsample):
|
if exists(post_downsample):
|
||||||
x = post_downsample(x)
|
x = post_downsample(x)
|
||||||
|
|
||||||
x = self.mid_block1(x, t, mid_c)
|
x = mid_block1(x, t, mid_c)
|
||||||
|
|
||||||
if exists(self.mid_attn):
|
if exists(mid_attn):
|
||||||
x = self.mid_attn(x)
|
x = mid_attn(x)
|
||||||
|
|
||||||
x = self.mid_block2(x, t, mid_c)
|
x = mid_block2(x, t, mid_c)
|
||||||
|
|
||||||
connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
|
connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
|
||||||
|
|
||||||
for init_block, resnet_blocks, attn, upsample in self.ups:
|
for init_block, resnet_blocks, attn, upsample in ups:
|
||||||
x = connect_skip(x)
|
x = connect_skip(x)
|
||||||
x = init_block(x, t, c)
|
x = init_block(x, t, c)
|
||||||
|
|
||||||
@@ -2139,7 +2180,7 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
x = torch.cat((x, r), dim = 1)
|
x = torch.cat((x, r), dim = 1)
|
||||||
|
|
||||||
x = self.final_resnet_block(x, t)
|
x = final_resnet_block(x, t)
|
||||||
|
|
||||||
if exists(lowres_cond_img):
|
if exists(lowres_cond_img):
|
||||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '1.4.6'
|
__version__ = '1.5.0'
|
||||||
|
|||||||
Reference in New Issue
Block a user