diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 89d64e9..fad469a 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -284,6 +284,8 @@ class Unet(nn.Module): def __init__( self, dim, + *, + image_embed_dim, out_dim = None, dim_mults=(1, 2, 4, 8), channels = 3, @@ -338,7 +340,15 @@ class Unet(nn.Module): nn.Conv2d(dim, out_dim, 1) ) - def forward(self, x, time): + def forward( + self, + x, + *, + image_embed, + time, + text_encodings = None, + cond_prob_drop = 0.2 + ): t = self.time_mlp(time) if exists(self.time_mlp) else None hiddens = []