complete ddim integration of diffusion prior as well as decoder for each unet, feature complete for https://github.com/lucidrains/DALLE2-pytorch/issues/157

This commit is contained in:
Phil Wang
2022-07-09 17:25:34 -07:00
parent 68de937aac
commit b7e22f7da0
5 changed files with 186 additions and 34 deletions

View File

@@ -154,6 +154,7 @@ class DiffusionPriorConfig(BaseModel):
image_size: int
image_channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[int] = None
cond_drop_prob: float = 0.
loss_type: str = 'l2'
predict_x_start: bool = True
@@ -233,6 +234,7 @@ class DecoderConfig(BaseModel):
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable(int)] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True