add gradient checkpointing for all resnet blocks

This commit is contained in:
Phil Wang
2022-08-02 19:21:44 -07:00
parent 451de34871
commit be3bb868bf
2 changed files with 52 additions and 11 deletions

View File

@@ -8,6 +8,7 @@ from pathlib import Path
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch import nn, einsum
import torchvision.transforms as T
@@ -108,6 +109,28 @@ def pad_tuple_to_length(t, length, fillvalue = None):
return t
return (*t, *((fillvalue,) * remain_length))
# checkpointing helper function
def make_checkpointable(fn, **kwargs):
if isinstance(fn, nn.ModuleList):
return [maybe(make_checkpointable)(el, **kwargs) for el in fn]
condition = kwargs.pop('condition', None)
if exists(condition) and not condition(fn):
return fn
@wraps(fn)
def inner(*args):
input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])
if not input_needs_grad:
return fn(*args)
return checkpoint(fn, *args)
return inner
# for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad):
@@ -1698,6 +1721,7 @@ class Unet(nn.Module):
pixel_shuffle_upsample = True,
final_conv_kernel_size = 1,
combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
checkpoint_during_training = False,
**kwargs
):
super().__init__()
@@ -1908,6 +1932,10 @@ class Unet(nn.Module):
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
# whether to checkpoint during training
self.checkpoint_during_training = checkpoint_during_training
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters(
@@ -1965,7 +1993,8 @@ class Unet(nn.Module):
image_cond_drop_prob = 0.,
text_cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
blur_kernel_size = None,
disable_checkpoint = False
):
batch_size, device = x.shape[0], x.device
@@ -2087,17 +2116,29 @@ class Unet(nn.Module):
c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c)
# gradient checkpointing
can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint
apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity
# make checkpointable modules
init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]
can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)
downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]
# initial resnet block
if exists(self.init_resnet_block):
x = self.init_resnet_block(x, t)
if exists(init_resnet_block):
x = init_resnet_block(x, t)
# go through the layers of the unet, down and up
down_hiddens = []
up_hiddens = []
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:
if exists(pre_downsample):
x = pre_downsample(x)
@@ -2113,16 +2154,16 @@ class Unet(nn.Module):
if exists(post_downsample):
x = post_downsample(x)
x = self.mid_block1(x, t, mid_c)
x = mid_block1(x, t, mid_c)
if exists(self.mid_attn):
x = self.mid_attn(x)
if exists(mid_attn):
x = mid_attn(x)
x = self.mid_block2(x, t, mid_c)
x = mid_block2(x, t, mid_c)
connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
for init_block, resnet_blocks, attn, upsample in self.ups:
for init_block, resnet_blocks, attn, upsample in ups:
x = connect_skip(x)
x = init_block(x, t, c)
@@ -2139,7 +2180,7 @@ class Unet(nn.Module):
x = torch.cat((x, r), dim = 1)
x = self.final_resnet_block(x, t)
x = final_resnet_block(x, t)
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)

View File

@@ -1 +1 @@
__version__ = '1.4.6'
__version__ = '1.5.0'