From de0296106bd9a300551625cd8513cc4ee0f975f6 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 26 Apr 2022 11:42:46 -0700 Subject: [PATCH] be able to turn off warning for use of LazyLinear by passing in text embedding dimension for unet --- dalle2_pytorch/dalle2_pytorch.py | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index f3dbdb5..42e3f88 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -906,6 +906,7 @@ class Unet(nn.Module): dim, *, image_embed_dim, + text_embed_dim = None, cond_dim = None, num_image_tokens = 4, num_time_tokens = 2, @@ -959,7 +960,7 @@ class Unet(nn.Module): Rearrange('b (n d) -> b n d', n = num_image_tokens) ) if image_embed_dim != cond_dim else nn.Identity() - self.text_to_cond = nn.LazyLinear(cond_dim) + self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim) # finer control over whether to condition on image embeddings and text encodings # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting diff --git a/setup.py b/setup.py index 24a42a2..36a7e2b 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.51', + version = '0.0.52', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',