diff --git a/README.md b/README.md index 819861e..af33271 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ It may also explore an extension of using Join us on Discord if you are interested in helping out with the replication +Do let me know if anyone is interested in a Jax version https://github.com/lucidrains/DALLE2-pytorch/discussions/8 + ## Citations ```bibtex diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 67ca8fc..7dcf012 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -42,6 +42,24 @@ def freeze_model_and_make_eval_(model): # diffusion prior +class Transformer(nn.Module): + def __init__( + self, + *, + dim, + dim_head = 64, + heads = 8, + + ): + super().__init__() + + def forward( + self, + x, + mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings + ): + return x + class DiffusionPrior(nn.Module): def __init__( self,