mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-22 19:14:20 +01:00
readme
This commit is contained in:
@@ -42,6 +42,24 @@ def freeze_model_and_make_eval_(model):
|
||||
|
||||
# diffusion prior
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
|
||||
):
|
||||
return x
|
||||
|
||||
class DiffusionPrior(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user