From 863f4ef2437b9ba9030c9e532b2b5972e66b3ee6 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 24 Apr 2022 10:06:42 -0700 Subject: [PATCH] just take care of the logic for setting all latent diffusion to predict x0, if needed --- dalle2_pytorch/dalle2_pytorch.py | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 791b6c0..b3793a4 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1122,6 +1122,7 @@ class Decoder(nn.Module): loss_type = 'l1', beta_schedule = 'cosine', predict_x0 = False, + predict_x0_for_latent_diffusion = False, image_sizes = None, # for cascading ddpm, image size at each stage lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur @@ -1172,7 +1173,7 @@ class Decoder(nn.Module): # predict x0 config - self.predict_x0 = cast_tuple(predict_x0, len(unets)) + self.predict_x0 = cast_tuple(predict_x0, len(unets)) if not predict_x0_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes)) # cascading ddpm related stuff diff --git a/setup.py b/setup.py index bc3cdc7..c00a455 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.41', + version = '0.0.42', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',