diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 80a7039..686db85 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -969,9 +969,10 @@ class Decoder(nn.Module): net, *, clip, - timesteps = 1000, - cond_drop_prob = 0.2, - loss_type = 'l1' + timesteps=1000, + cond_drop_prob=0.2, + loss_type="l1", + beta_schedule="cosine", ): super().__init__() assert isinstance(clip, CLIP)