Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-23 11:34:20 +01:00
complete helper methods for doing condition scaling (classifier free guidance), for decoder unet and prior network
@@ -179,6 +179,20 @@ class DiffusionPriorNetwork(nn.Module):
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = Transformer(**kwargs)
 
+    def forward_with_cond_scale(
+        self,
+        x,
+        *,
+        cond_scale = 1.,
+        **kwargs
+    ):
+        if cond_scale == 1:
+            return self.forward(x, **kwargs)
+
+        logits = self.forward(x, **kwargs)
+        null_logits = self.forward(x, cond_prob_drop = 1., **kwargs)
+        return null_logits + (logits - null_logits) * cond_scale
+
     def forward(
         self,
         image_embed,
@@ -371,6 +385,20 @@ class Unet(nn.Module):
             nn.Conv2d(dim, out_dim, 1)
         )
 
+    def forward_with_cond_scale(
+        self,
+        x,
+        *,
+        cond_scale = 1.,
+        **kwargs
+    ):
+        if cond_scale == 1:
+            return self.forward(x, **kwargs)
+
+        logits = self.forward(x, **kwargs)
+        null_logits = self.forward(x, cond_prob_drop = 1., **kwargs)
+        return null_logits + (logits - null_logits) * cond_scale
+
     def forward(
         self,
         x,
@@ -378,7 +406,7 @@ class Unet(nn.Module):
         image_embed,
         time,
         text_encodings = None,
-        cond_prob_drop = 0.2
+        cond_prob_drop = 0.
     ):
         batch_size, device = image_embed.shape[0], image_embed.device
         t = self.time_mlp(time) if exists(self.time_mlp) else None
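
The helper added to both classes implements classifier-free guidance: the model is run once with conditioning and once with conditioning fully dropped (cond_prob_drop = 1.), and the output is extrapolated from the unconditional prediction toward the conditional one by cond_scale. Below is a minimal, self-contained sketch of that pattern; ToyConditionalNet and its null_cond parameter are hypothetical stand-ins, not part of the DALLE2-pytorch code, and only the guidance arithmetic mirrors the diff above.

import torch
from torch import nn

class ToyConditionalNet(nn.Module):
    def __init__(self, dim = 8):
        super().__init__()
        self.net = nn.Linear(dim * 2, dim)
        # hypothetical learned "no conditioning" embedding, used when conditioning is dropped
        self.null_cond = nn.Parameter(torch.zeros(dim))

    def forward(self, x, cond, cond_prob_drop = 0.):
        if cond_prob_drop == 1.:
            # fully drop the conditioning, mirroring forward(..., cond_prob_drop = 1.) in the diff
            cond = self.null_cond.expand_as(cond)
        return self.net(torch.cat((x, cond), dim = -1))

    def forward_with_cond_scale(self, x, *, cond_scale = 1., **kwargs):
        if cond_scale == 1:
            return self.forward(x, **kwargs)

        logits = self.forward(x, **kwargs)                            # conditional pass
        null_logits = self.forward(x, cond_prob_drop = 1., **kwargs)  # unconditional pass
        # extrapolate from the unconditional output toward the conditional one
        return null_logits + (logits - null_logits) * cond_scale

x = torch.randn(2, 8)
cond = torch.randn(2, 8)
model = ToyConditionalNet()
guided = model.forward_with_cond_scale(x, cond = cond, cond_scale = 3.)
print(guided.shape)  # torch.Size([2, 8])

With cond_scale = 1 the helper reduces to a plain conditional forward pass; values above 1 push the output further from the unconditional prediction, trading sample diversity for stronger adherence to the conditioning.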