From 59b8abe09e07fb0f6c25b5a99e15c621c66d0096 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 12 Apr 2022 12:38:56 -0700 Subject: [PATCH] prepare unet to be conditioned on image embedding, optionally text encodings, and reminder for self to build conditional dropout for classifier free guidance --- dalle2_pytorch/dalle2_pytorch.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 89d64e9..fad469a 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -284,6 +284,8 @@ class Unet(nn.Module): def __init__( self, dim, + *, + image_embed_dim, out_dim = None, dim_mults=(1, 2, 4, 8), channels = 3, @@ -338,7 +340,15 @@ class Unet(nn.Module): nn.Conv2d(dim, out_dim, 1) ) - def forward(self, x, time): + def forward( + self, + x, + *, + image_embed, + time, + text_encodings = None, + cond_prob_drop = 0.2 + ): t = self.time_mlp(time) if exists(self.time_mlp) else None hiddens = []