Compare commits

..

5 Commits
1.0.5 ... 1.2.1

Author SHA1 Message Date
Phil Wang
36fb46a95e fix readme and a small bug in DALLE2 class 2022-07-28 08:33:51 -07:00
Phil Wang
07abfcf45b rescale values in linear attention to mitigate overflows in fp16 setting 2022-07-27 12:27:38 -07:00
Phil Wang
2e35a9967d product management 2022-07-26 11:10:16 -07:00
Phil Wang
406e75043f add upsample combiner feature for the unets 2022-07-26 10:46:04 -07:00
Phil Wang
9646dfc0e6 fix path_or_state bug 2022-07-26 09:47:54 -07:00
4 changed files with 65 additions and 10 deletions

View File

@@ -371,6 +371,7 @@ loss.backward()
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8),
@@ -1112,7 +1113,8 @@ For detailed information on training the diffusion prior, please refer to the [d
- [x] allow for unet to be able to condition non-cross attention style as well
- [x] speed up inference, read up on papers (ddim)
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
## Citations

View File

@@ -1503,6 +1503,7 @@ class LinearAttention(nn.Module):
k = k.softmax(dim = -2)
q = q * self.scale
v = v / (x * y)
context = einsum('b n d, b n e -> b d e', k, v)
out = einsum('b n d, b d e -> b n e', q, context)
@@ -1538,6 +1539,38 @@ class CrossEmbedLayer(nn.Module):
fmaps = tuple(map(lambda conv: conv(x), self.convs))
return torch.cat(fmaps, dim = 1)
class UpsampleCombiner(nn.Module):
def __init__(
self,
dim,
*,
enabled = False,
dim_ins = tuple(),
dim_outs = tuple()
):
super().__init__()
assert len(dim_ins) == len(dim_outs)
self.enabled = enabled
if not self.enabled:
self.dim_out = dim
return
self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
def forward(self, x, fmaps = None):
target_size = x.shape[-1]
fmaps = default(fmaps, tuple())
if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
return x
fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
return torch.cat((x, *outs), dim = 1)
class Unet(nn.Module):
def __init__(
self,
@@ -1575,6 +1608,7 @@ class Unet(nn.Module):
scale_skip_connection = False,
pixel_shuffle_upsample = True,
final_conv_kernel_size = 1,
combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
**kwargs
):
super().__init__()
@@ -1710,7 +1744,8 @@ class Unet(nn.Module):
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
skip_connect_dims = [] # keeping track of skip connection dimensions
skip_connect_dims = [] # keeping track of skip connection dimensions
upsample_combiner_dims = [] # keeping track of dimensions for final upsample feature map combiner
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
is_first = ind == 0
@@ -1752,6 +1787,8 @@ class Unet(nn.Module):
elif sparse_attn:
attention = Residual(LinearAttention(dim_out, **attn_kwargs))
upsample_combiner_dims.append(dim_out)
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
@@ -1759,7 +1796,18 @@ class Unet(nn.Module):
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
]))
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
# whether to combine outputs from all upsample blocks for final resnet block
self.upsample_combiner = UpsampleCombiner(
dim = dim,
enabled = combine_upsample_fmaps,
dim_ins = upsample_combiner_dims,
dim_outs = (dim,) * len(upsample_combiner_dims)
)
# a final resnet block
self.final_resnet_block = ResnetBlock(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
out_dim_in = dim + (channels if lowres_cond else 0)
@@ -1953,7 +2001,8 @@ class Unet(nn.Module):
# go through the layers of the unet, down and up
hiddens = []
down_hiddens = []
up_hiddens = []
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
if exists(pre_downsample):
@@ -1963,10 +2012,10 @@ class Unet(nn.Module):
for resnet_block in resnet_blocks:
x = resnet_block(x, t, c)
hiddens.append(x)
down_hiddens.append(x.contiguous())
x = attn(x)
hiddens.append(x.contiguous())
down_hiddens.append(x.contiguous())
if exists(post_downsample):
x = post_downsample(x)
@@ -1978,7 +2027,7 @@ class Unet(nn.Module):
x = self.mid_block2(x, t, mid_c)
connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
for init_block, resnet_blocks, attn, upsample in self.ups:
x = connect_skip(x)
@@ -1989,8 +2038,12 @@ class Unet(nn.Module):
x = resnet_block(x, t, c)
x = attn(x)
up_hiddens.append(x.contiguous())
x = upsample(x)
x = self.upsample_combiner(x, up_hiddens)
x = torch.cat((x, r), dim = 1)
x = self.final_resnet_block(x, t)
@@ -2885,7 +2938,7 @@ class DALLE2(nn.Module):
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)
if return_pil_images:
images = list(map(self.to_pil, images.unbind(dim = 0)))

View File

@@ -300,7 +300,7 @@ class DiffusionPriorTrainer(nn.Module):
# all processes need to load checkpoint. no restriction here
if isinstance(path_or_state, str):
path = Path(path)
path = Path(path_or_state)
assert path.exists()
loaded_obj = torch.load(str(path), map_location=self.device)

View File

@@ -1 +1 @@
__version__ = '1.0.5'
__version__ = '1.2.1'