From 51361c2d15478ab8b6b9a240e1940fda74e0914a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 17 Apr 2022 15:19:33 +0200 Subject: [PATCH] added beta_schedule argument --- dalle2_pytorch/dalle2_pytorch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 80a7039..686db85 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -969,9 +969,10 @@ class Decoder(nn.Module): net, *, clip, - timesteps = 1000, - cond_drop_prob = 0.2, - loss_type = 'l1' + timesteps=1000, + cond_drop_prob=0.2, + loss_type="l1", + beta_schedule="cosine", ): super().__init__() assert isinstance(clip, CLIP)