Compare commits

..

2 Commits

Author SHA1 Message Date
Phil Wang
6647050c33 fix missing resisidual for highest resolution of the unet 2022-06-15 18:01:19 -07:00
Giorgos Zachariadis
b4c3e5b854 changed str in order to avoid confusions and collisions with Python (#147) 2022-06-15 13:41:16 -07:00
3 changed files with 10 additions and 5 deletions

View File

@@ -1489,8 +1489,10 @@ class Unet(nn.Module):
Upsample(dim_in)
]))
final_dim_in = dim * (1 if memory_efficient else 2)
self.final_conv = nn.Sequential(
ResnetBlock(dim, dim, groups = resnet_groups[0]),
ResnetBlock(final_dim_in, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, self.channels_out, 1)
)
@@ -1682,7 +1684,7 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c, t)
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = torch.cat((x, hiddens.pop()), dim = 1)
x = init_block(x, c, t)
x = sparse_attn(x)
@@ -1691,6 +1693,9 @@ class Unet(nn.Module):
x = upsample(x)
if len(hiddens) > 0:
x = torch.cat((x, hiddens.pop()), dim = 1)
return self.final_conv(x)
class LowresConditioner(nn.Module):

View File

@@ -1 +1 @@
__version__ = '0.8.1'
__version__ = '0.9.0'

View File

@@ -68,8 +68,8 @@ def group_dict_by_key(cond, d):
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def string_begins_with(prefix, string_input):
return string_input.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)