diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py index cc133fa..ee67341 100644 --- a/dalle2_pytorch/train_configs.py +++ b/dalle2_pytorch/train_configs.py @@ -4,11 +4,13 @@ from pydantic import BaseModel, validator, root_validator from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar from x_clip import CLIP as XCLIP +from open_clip import list_pretrained from coca_pytorch import CoCa from dalle2_pytorch.dalle2_pytorch import ( CoCaAdapter, OpenAIClipAdapter, + OpenClipAdapter, Unet, Decoder, DiffusionPrior, @@ -117,6 +119,10 @@ class AdapterConfig(BaseModel): def create(self): if self.make == "openai": return OpenAIClipAdapter(self.model) + elif self.make == "open_clip": + pretrained = dict(list_pretrained()) + checkpoint = pretrained[self.model] + return OpenClipAdapter(name=self.model, pretrained=checkpoint) elif self.make == "x-clip": return XClipAdapter(XCLIP(**self.base_model_kwargs)) elif self.make == "coca":