diff --git a/dalle2_pytorch/cli.py b/dalle2_pytorch/cli.py index aca5267..a2a6650 100644 --- a/dalle2_pytorch/cli.py +++ b/dalle2_pytorch/cli.py @@ -1,6 +1,7 @@ import click import torch import torchvision.transforms as T +from functools import reduce from pathlib import Path from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior