This commit is contained in:
Phil Wang
2022-04-14 08:30:07 -07:00
parent 7fb3f695d5
commit 82464d7bd3

View File

@@ -285,17 +285,16 @@ class DiffusionPriorNetwork(nn.Module):
def forward_with_cond_scale(
self,
x,
*args,
cond_scale = 1.,
**kwargs
):
logits = self.forward(x, *args, **kwargs)
logits = self.forward(*args, **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
@@ -675,17 +674,16 @@ class Unet(nn.Module):
def forward_with_cond_scale(
self,
x,
*args,
cond_scale = 1.,
**kwargs
):
logits = self.forward(x, *args, **kwargs)
logits = self.forward(*args, **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(