force einops 0.6.1 or greater and call allow_ops_in_compiled_graph

This commit is contained in:
Phil Wang
2023-04-20 14:08:52 -07:00
parent 0069857cf8
commit 00e07b7d61
3 changed files with 9 additions and 2 deletions

View File

@@ -1,3 +1,10 @@
import torch
from packaging import version
if version.parse(torch.__version__) >= version.parse('2.0.0'):
from einops._torch_specific import allow_ops_in_compiled_graph
allow_ops_in_compiled_graph()
from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter