mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
just take care of the logic for setting all latent diffusion to predict x0, if needed
This commit is contained in:
@@ -1122,6 +1122,7 @@ class Decoder(nn.Module):
|
|||||||
loss_type = 'l1',
|
loss_type = 'l1',
|
||||||
beta_schedule = 'cosine',
|
beta_schedule = 'cosine',
|
||||||
predict_x0 = False,
|
predict_x0 = False,
|
||||||
|
predict_x0_for_latent_diffusion = False,
|
||||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||||
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
||||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||||
@@ -1172,7 +1173,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
# predict x0 config
|
# predict x0 config
|
||||||
|
|
||||||
self.predict_x0 = cast_tuple(predict_x0, len(unets))
|
self.predict_x0 = cast_tuple(predict_x0, len(unets)) if not predict_x0_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
||||||
|
|
||||||
# cascading ddpm related stuff
|
# cascading ddpm related stuff
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user