mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 18:24:19 +01:00
also consider accepting tokenizer, so dalle2 forward pass can just be invoked as DALLE2(<prompt string>)
This commit is contained in:
@@ -180,12 +180,14 @@ class DALLE2(nn.Module):
|
|||||||
*,
|
*,
|
||||||
clip,
|
clip,
|
||||||
prior,
|
prior,
|
||||||
decoder
|
decoder,
|
||||||
|
tokenizer = None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(clip), CLIP
|
assert isinstance(clip), CLIP
|
||||||
assert isinstance(prior), DiffusionPrior
|
assert isinstance(prior), DiffusionPrior
|
||||||
assert isinstance(decoder), Decoder
|
assert isinstance(decoder), Decoder
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Reference in New Issue
Block a user