mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 03:54:35 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
863f4ef243 |
@@ -1122,6 +1122,7 @@ class Decoder(nn.Module):
|
||||
loss_type = 'l1',
|
||||
beta_schedule = 'cosine',
|
||||
predict_x0 = False,
|
||||
predict_x0_for_latent_diffusion = False,
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
@@ -1172,7 +1173,7 @@ class Decoder(nn.Module):
|
||||
|
||||
# predict x0 config
|
||||
|
||||
self.predict_x0 = cast_tuple(predict_x0, len(unets))
|
||||
self.predict_x0 = cast_tuple(predict_x0, len(unets)) if not predict_x0_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
||||
|
||||
# cascading ddpm related stuff
|
||||
|
||||
|
||||
Reference in New Issue
Block a user