Complete helper methods for condition scaling (classifier-free guidance) for the decoder Unet and the prior network

This commit is contained in:
Phil Wang
2022-04-12 16:11:16 -07:00
parent d4c8373635
commit d546a615c0

View File

@@ -179,6 +179,20 @@ class DiffusionPriorNetwork(nn.Module):
self.learned_query = nn.Parameter(torch.randn(dim)) self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = Transformer(**kwargs) self.causal_transformer = Transformer(**kwargs)
def forward_with_cond_scale(
    self,
    x,
    *,
    cond_scale = 1.,
    **kwargs
):
    """Run ``forward`` with condition scaling (classifier free guidance).

    Args:
        x: first positional input, handed to ``self.forward`` unchanged.
        cond_scale: weight applied to the difference between the regular
            and the null-conditioned outputs. When equal to 1 only a
            single forward pass is made and its output is returned as is.
        **kwargs: forwarded to ``self.forward``. A caller-supplied
            ``cond_prob_drop`` is used for the regular pass only.

    Returns:
        ``null_logits + (logits - null_logits) * cond_scale``, or the
        plain ``forward`` output when ``cond_scale == 1``.
    """
    logits = self.forward(x, **kwargs)

    if cond_scale == 1:
        return logits

    # Override (rather than add) the key: passing cond_prob_drop both
    # explicitly and through **kwargs would raise a TypeError.
    # NOTE(review): cond_prob_drop = 1. presumably drops all conditioning
    # inside forward - confirm against the forward implementation.
    null_kwargs = {**kwargs, 'cond_prob_drop': 1.}
    null_logits = self.forward(x, **null_kwargs)

    return null_logits + (logits - null_logits) * cond_scale
def forward( def forward(
self, self,
image_embed, image_embed,
@@ -371,6 +385,20 @@ class Unet(nn.Module):
nn.Conv2d(dim, out_dim, 1) nn.Conv2d(dim, out_dim, 1)
) )
def forward_with_cond_scale(
    self,
    x,
    *,
    cond_scale = 1.,
    **kwargs
):
    """Combine a regular and a null-conditioned ``forward`` pass.

    Implements condition scaling (classifier free guidance) on top of
    ``self.forward``.

    Args:
        x: first positional input, passed through to ``self.forward``.
        cond_scale: guidance weight. A value of 1 short-circuits to a
            single ``forward`` call.
        **kwargs: extra keyword arguments for ``self.forward``. If the
            caller includes ``cond_prob_drop`` it only affects the regular
            pass; the null pass always uses ``cond_prob_drop = 1.``.

    Returns:
        The ``forward`` output when ``cond_scale == 1``; otherwise
        ``null_logits + (logits - null_logits) * cond_scale``.
    """
    logits = self.forward(x, **kwargs)

    if cond_scale == 1:
        return logits

    # Build the null-pass kwargs by overriding cond_prob_drop so that a
    # caller-provided value cannot collide with ours (duplicate keyword
    # arguments raise a TypeError).
    # NOTE(review): assumes forward treats cond_prob_drop = 1. as "drop
    # all conditioning" - verify in forward.
    null_kwargs = {**kwargs, 'cond_prob_drop': 1.}
    null_logits = self.forward(x, **null_kwargs)

    return null_logits + (logits - null_logits) * cond_scale
def forward( def forward(
self, self,
x, x,
@@ -378,7 +406,7 @@ class Unet(nn.Module):
image_embed, image_embed,
time, time,
text_encodings = None, text_encodings = None,
cond_prob_drop = 0.2 cond_prob_drop = 0.
): ):
batch_size, device = image_embed.shape[0], image_embed.device batch_size, device = image_embed.shape[0], image_embed.device
t = self.time_mlp(time) if exists(self.time_mlp) else None t = self.time_mlp(time) if exists(self.time_mlp) else None