This commit is contained in:
Phil Wang
2022-04-12 10:12:42 -07:00
parent 2ab042b862
commit 24b428bdfc
2 changed files with 20 additions and 0 deletions

View File

@@ -12,6 +12,8 @@ It may also explore an extension of using <a href="https://huggingface.co/spaces
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
Do let me know if anyone is interested in a Jax version https://github.com/lucidrains/DALLE2-pytorch/discussions/8
## Citations ## Citations
```bibtex ```bibtex

View File

@@ -42,6 +42,24 @@ def freeze_model_and_make_eval_(model):
# diffusion prior # diffusion prior
class Transformer(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
):
super().__init__()
def forward(
self,
x,
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
):
return x
class DiffusionPrior(nn.Module): class DiffusionPrior(nn.Module):
def __init__( def __init__(
self, self,