From d546a615c07a6027710829bd0e924336467bd1e2 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 12 Apr 2022 16:11:16 -0700 Subject: [PATCH] complete helper methods for doing condition scaling (classifier free guidance), for decoder unet and prior network --- dalle2_pytorch/dalle2_pytorch.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 136c4e7..3c2aaf0 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -179,6 +179,20 @@ class DiffusionPriorNetwork(nn.Module): self.learned_query = nn.Parameter(torch.randn(dim)) self.causal_transformer = Transformer(**kwargs) + def forward_with_cond_scale( + self, + x, + *, + cond_scale = 1., + **kwargs + ): + if cond_scale == 1: + return self.forward(x, **kwargs) + + logits = self.forward(x, **kwargs) + null_logits = self.forward(x, cond_prob_drop = 1., **kwargs) + return null_logits + (logits - null_logits) * cond_scale + def forward( self, image_embed, @@ -371,6 +385,20 @@ class Unet(nn.Module): nn.Conv2d(dim, out_dim, 1) ) + def forward_with_cond_scale( + self, + x, + *, + cond_scale = 1., + **kwargs + ): + if cond_scale == 1: + return self.forward(x, **kwargs) + + logits = self.forward(x, **kwargs) + null_logits = self.forward(x, cond_prob_drop = 1., **kwargs) + return null_logits + (logits - null_logits) * cond_scale + def forward( self, x, @@ -378,7 +406,7 @@ class Unet(nn.Module): image_embed, time, text_encodings = None, - cond_prob_drop = 0.2 + cond_prob_drop = 0. ): batch_size, device = image_embed.shape[0], image_embed.device t = self.time_mlp(time) if exists(self.time_mlp) else None