Pin to a newer version of CLIP (x-clip>=0.4.1) that returns encoded text and images; get some helper functions ready for XCLIP.

This commit is contained in:
Phil Wang
2022-04-12 08:54:40 -07:00
parent 0070547e3b
commit 4ff6d021c9
2 changed files with 43 additions and 1 deletions

View File

@@ -3,6 +3,48 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
# use x-clip
from x_clip import CLIP
# helper functions
def exists(val):
    """Return True when *val* is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing;
    only the ``None`` singleton is treated as missing.
    """
    return val is not None
def default(val, d):
    """Return *val* unless it is ``None``, in which case return *d*.

    The fallback *d* is returned as-is (it is not called), and falsy but
    non-``None`` values of *val* are kept rather than replaced.
    """
    return val if exists(val) else d
# for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad):
    """Set ``requires_grad`` on every parameter of *module*, in place.

    Iterates ``module.parameters()`` and assigns *requires_grad* to each
    parameter's ``requires_grad`` attribute. Returns ``None``; the trailing
    underscore in the name signals the in-place mutation.
    """
    for param in module.parameters():
        param.requires_grad = requires_grad
def freeze_all_layers_(module):
    """Turn off ``requires_grad`` for every parameter of *module*, in place."""
    set_module_requires_grad_(module, False)
def unfreeze_all_layers_(module):
    """Turn on ``requires_grad`` for every parameter of *module*, in place."""
    set_module_requires_grad_(module, True)
# diffusion prior
class DiffusionPrior(nn.Module):
    """Placeholder module for the diffusion prior.

    Currently a stub: it registers no parameters or submodules and its
    forward pass is the identity.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return *x* unchanged."""
        return x
# decoder
class Decoder(nn.Module):
    """Placeholder module for the decoder.

    Currently a stub: it registers no parameters or submodules and its
    forward pass is the identity.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return *x* unchanged."""
        return x
# main class
class DALLE2(nn.Module):
def __init__(self):
super().__init__()

View File

@@ -17,7 +17,7 @@ setup(
install_requires=[
'einops>=0.4',
'torch>=1.6',
'x-clip'
'x-clip>=0.4.1'
],
classifiers=[
'Development Status :: 4 - Beta',