mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
per-fect
This commit is contained in:
@@ -285,17 +285,16 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
|
|
||||||
def forward_with_cond_scale(
|
def forward_with_cond_scale(
|
||||||
self,
|
self,
|
||||||
x,
|
|
||||||
*args,
|
*args,
|
||||||
cond_scale = 1.,
|
cond_scale = 1.,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
logits = self.forward(x, *args, **kwargs)
|
logits = self.forward(*args, **kwargs)
|
||||||
|
|
||||||
if cond_scale == 1:
|
if cond_scale == 1:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
|
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
|
||||||
return null_logits + (logits - null_logits) * cond_scale
|
return null_logits + (logits - null_logits) * cond_scale
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -675,17 +674,16 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
def forward_with_cond_scale(
|
def forward_with_cond_scale(
|
||||||
self,
|
self,
|
||||||
x,
|
|
||||||
*args,
|
*args,
|
||||||
cond_scale = 1.,
|
cond_scale = 1.,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
logits = self.forward(x, *args, **kwargs)
|
logits = self.forward(*args, **kwargs)
|
||||||
|
|
||||||
if cond_scale == 1:
|
if cond_scale == 1:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
|
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
|
||||||
return null_logits + (logits - null_logits) * cond_scale
|
return null_logits + (logits - null_logits) * cond_scale
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Reference in New Issue
Block a user