diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 41b0db7..700b18a 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -285,17 +285,16 @@ class DiffusionPriorNetwork(nn.Module): def forward_with_cond_scale( self, - x, *args, cond_scale = 1., **kwargs ): - logits = self.forward(x, *args, **kwargs) + logits = self.forward(*args, **kwargs) if cond_scale == 1: return logits - null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs) + null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) return null_logits + (logits - null_logits) * cond_scale def forward( @@ -675,17 +674,16 @@ class Unet(nn.Module): def forward_with_cond_scale( self, - x, *args, cond_scale = 1., **kwargs ): - logits = self.forward(x, *args, **kwargs) + logits = self.forward(*args, **kwargs) if cond_scale == 1: return logits - null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs) + null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) return null_logits + (logits - null_logits) * cond_scale def forward(