mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 18:24:19 +01:00
also consider accepting tokenizer, so dalle2 forward pass can just be invoked as DALLE2(<prompt string>)
This commit is contained in:
@@ -180,12 +180,14 @@ class DALLE2(nn.Module):
|
||||
*,
|
||||
clip,
|
||||
prior,
|
||||
decoder
|
||||
decoder,
|
||||
tokenizer = None
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip), CLIP
|
||||
assert isinstance(prior), DiffusionPrior
|
||||
assert isinstance(decoder), Decoder
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user