mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
just take care of the logic for setting all latent diffusion to predict x0, if needed
This commit is contained in:
@@ -1122,6 +1122,7 @@ class Decoder(nn.Module):
|
||||
loss_type = 'l1',
|
||||
beta_schedule = 'cosine',
|
||||
predict_x0 = False,
|
||||
predict_x0_for_latent_diffusion = False,
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
@@ -1172,7 +1173,7 @@ class Decoder(nn.Module):
|
||||
|
||||
# predict x0 config
|
||||
|
||||
self.predict_x0 = cast_tuple(predict_x0, len(unets))
|
||||
self.predict_x0 = cast_tuple(predict_x0, len(unets)) if not predict_x0_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
||||
|
||||
# cascading ddpm related stuff
|
||||
|
||||
|
||||
Reference in New Issue
Block a user