just take care of the logic for setting all latent diffusion to predict x0, if needed

This commit is contained in:
Phil Wang
2022-04-24 10:06:42 -07:00
parent fb8a66a2de
commit 863f4ef243
2 changed files with 3 additions and 2 deletions

View File

@@ -1122,6 +1122,7 @@ class Decoder(nn.Module):
loss_type = 'l1', loss_type = 'l1',
beta_schedule = 'cosine', beta_schedule = 'cosine',
predict_x0 = False, predict_x0 = False,
predict_x0_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage image_sizes = None, # for cascading ddpm, image size at each stage
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
@@ -1172,7 +1173,7 @@ class Decoder(nn.Module):
# predict x0 config # predict x0 config
self.predict_x0 = cast_tuple(predict_x0, len(unets)) self.predict_x0 = cast_tuple(predict_x0, len(unets)) if not predict_x0_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
# cascading ddpm related stuff # cascading ddpm related stuff

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.41', version = '0.0.42',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',