Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 01:34:19 +01:00)
add gradient checkpointing for all resnet blocks
@@ -8,6 +8,7 @@ from pathlib import Path
 import torch
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 from torch import nn, einsum
 import torchvision.transforms as T
@@ -108,6 +109,28 @@ def pad_tuple_to_length(t, length, fillvalue = None):
         return t
     return (*t, *((fillvalue,) * remain_length))

+# checkpointing helper function
+
+def make_checkpointable(fn, **kwargs):
+    if isinstance(fn, nn.ModuleList):
+        return [maybe(make_checkpointable)(el, **kwargs) for el in fn]
+
+    condition = kwargs.pop('condition', None)
+
+    if exists(condition) and not condition(fn):
+        return fn
+
+    @wraps(fn)
+    def inner(*args):
+        input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])
+
+        if not input_needs_grad:
+            return fn(*args)
+
+        return checkpoint(fn, *args)
+
+    return inner
+
 # for controlling freezing of CLIP

 def set_module_requires_grad_(module, requires_grad):
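Note on the primitive: `torch.utils.checkpoint.checkpoint` runs `fn` without storing intermediate activations, then reruns the forward inside the backward pass to recompute them, trading compute for memory. The `input_needs_grad` guard above skips checkpointing whenever no input tensor requires grad, since recomputation would buy nothing and no gradients would flow anyway. A minimal sketch of the trade-off, independent of this repo (the `block` module is an arbitrary stand-in; newer PyTorch versions warn unless `use_reentrant` is passed explicitly):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    block = nn.Sequential(nn.Conv2d(8, 8, 3, padding = 1), nn.SiLU())
    x = torch.randn(2, 8, 16, 16, requires_grad = True)

    y_plain = block(x)              # activations kept alive for backward
    y_ckpt = checkpoint(block, x)   # activations discarded, recomputed in backward

    assert torch.allclose(y_plain, y_ckpt)
    y_ckpt.sum().backward()         # gradients still reach the conv weights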
@@ -1698,6 +1721,7 @@ class Unet(nn.Module):
         pixel_shuffle_upsample = True,
         final_conv_kernel_size = 1,
         combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
+        checkpoint_during_training = False,
         **kwargs
     ):
         super().__init__()
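A hedged usage sketch for the new constructor flag: `checkpoint_during_training` comes from this commit, while the remaining keyword arguments (`dim`, `image_embed_dim`, `cond_dim`, `dim_mults`) are assumptions about the surrounding `Unet` signature, and their values here are arbitrary:

    from dalle2_pytorch import Unet

    unet = Unet(
        dim = 128,
        image_embed_dim = 512,
        cond_dim = 128,
        dim_mults = (1, 2, 4, 8),
        checkpoint_during_training = True   # trade forward recompute for activation memory
    )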
@@ -1908,6 +1932,10 @@ class Unet(nn.Module):
         zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it

+        # whether to checkpoint during training
+
+        self.checkpoint_during_training = checkpoint_during_training
+
     # if the current settings for the unet are not correct
     # for cascading DDPM, then reinit the unet with the right settings

     def cast_model_parameters(
@@ -1965,7 +1993,8 @@ class Unet(nn.Module):
         image_cond_drop_prob = 0.,
         text_cond_drop_prob = 0.,
         blur_sigma = None,
-        blur_kernel_size = None
+        blur_kernel_size = None,
+        disable_checkpoint = False
     ):
         batch_size, device = x.shape[0], x.device
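`disable_checkpoint` is a per-call override; note that `can_checkpoint` in the next hunk also requires `self.training`, and under `torch.no_grad()` the `input_needs_grad` guard already falls through to a plain forward, so sampling is unaffected either way. A hypothetical call (only `disable_checkpoint` is from this diff; the positional and `image_embed` arguments are assumptions about the rest of the signature):

    pred = unet(
        noised_image,
        times,
        image_embed = image_embed,
        disable_checkpoint = True   # force a plain forward even in training mode
    )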
@@ -2087,17 +2116,29 @@ class Unet(nn.Module):
         c = self.norm_cond(c)
         mid_c = self.norm_mid_cond(mid_c)

+        # gradient checkpointing
+
+        can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint
+        apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity
+
+        # make checkpointable modules
+
+        init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]
+
+        can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)
+        downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]
+
         # initial resnet block

-        if exists(self.init_resnet_block):
-            x = self.init_resnet_block(x, t)
+        if exists(init_resnet_block):
+            x = init_resnet_block(x, t)

         # go through the layers of the unet, down and up

         down_hiddens = []
         up_hiddens = []

-        for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
+        for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:
             if exists(pre_downsample):
                 x = pre_downsample(x)
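The `condition` kwarg is what scopes checkpointing to resnet blocks only: `make_checkpointable` recurses into each `ModuleList`, wraps elements that satisfy the predicate (with `maybe` passing `None` entries through), and returns everything else untouched, so the attention and up/downsample layers inside `downs` and `ups` keep their ordinary forward. A toy illustration, assuming `make_checkpointable` is importable from the module as below; the two blocks are stand-ins, not this repo's `ResnetBlock` or attention classes:

    import torch
    from torch import nn
    from dalle2_pytorch.dalle2_pytorch import make_checkpointable

    conv = nn.Conv2d(4, 4, 3, padding = 1)   # stand-in for a ResnetBlock
    gelu = nn.GELU()                         # stand-in for an attention block

    blocks = nn.ModuleList([conv, gelu])
    wrapped = make_checkpointable(blocks, condition = lambda m: isinstance(m, nn.Conv2d))

    assert wrapped[1] is gelu      # fails the condition, returned as-is
    assert wrapped[0] is not conv  # passes, now a checkpointed wrapper

    x = torch.randn(1, 4, 8, 8, requires_grad = True)
    out = wrapped[0](x)            # forward runs under torch.utils.checkpoint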
@@ -2113,16 +2154,16 @@ class Unet(nn.Module):
             if exists(post_downsample):
                 x = post_downsample(x)

-        x = self.mid_block1(x, t, mid_c)
+        x = mid_block1(x, t, mid_c)

-        if exists(self.mid_attn):
-            x = self.mid_attn(x)
+        if exists(mid_attn):
+            x = mid_attn(x)

-        x = self.mid_block2(x, t, mid_c)
+        x = mid_block2(x, t, mid_c)

         connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)

-        for init_block, resnet_blocks, attn, upsample in self.ups:
+        for init_block, resnet_blocks, attn, upsample in ups:
             x = connect_skip(x)
             x = init_block(x, t, c)
@@ -2139,7 +2180,7 @@ class Unet(nn.Module):
         x = torch.cat((x, r), dim = 1)

-        x = self.final_resnet_block(x, t)
+        x = final_resnet_block(x, t)

         if exists(lowres_cond_img):
             x = torch.cat((x, lowres_cond_img), dim = 1)
@@ -1 +1 @@
-__version__ = '1.4.6'
+__version__ = '1.5.0'