Compare commits

...

5 Commits

3 changed files with 12 additions and 4 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1 +1 @@
github: [lucidrains]
github: [nousr, Veldrovive, lucidrains]

View File

@@ -1731,7 +1731,10 @@ class Unet(nn.Module):
]))
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
out_dim_in = dim + (channels if lowres_cond else 0)
self.to_out = nn.Conv2d(out_dim_in, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
@@ -1923,7 +1926,7 @@ class Unet(nn.Module):
hiddens.append(x)
x = attn(x)
hiddens.append(x)
hiddens.append(x.contiguous())
if exists(post_downsample):
x = post_downsample(x)
@@ -1951,6 +1954,10 @@ class Unet(nn.Module):
x = torch.cat((x, r), dim = 1)
x = self.final_resnet_block(x, t)
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
return self.to_out(x)
class LowresConditioner(nn.Module):
@@ -2165,6 +2172,7 @@ class Decoder(nn.Module):
# random crop sizes (for super-resoluting unets at the end of cascade?)
self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))
assert not exists(self.random_crop_sizes[0]), 'you would not need to randomly crop the image for the base unet'
# predict x0 config

View File

@@ -1 +1 @@
__version__ = '0.23.10'
__version__ = '0.24.2'