mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
pin to newer version of CLIP that returns encoded text and images, get some helper functions ready for XCLIP
This commit is contained in:
@@ -3,6 +3,48 @@ import torch.nn.functional as F
|
|||||||
from torch import nn, einsum
|
from torch import nn, einsum
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
|
# use x-clip
|
||||||
|
|
||||||
|
from x_clip import CLIP
|
||||||
|
|
||||||
|
# helper functions
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
# for controlling freezing of CLIP
|
||||||
|
|
||||||
|
def set_module_requires_grad_(module, requires_grad):
|
||||||
|
for param in module.parameters():
|
||||||
|
param.requires_grad = requires_grad
|
||||||
|
|
||||||
|
def freeze_all_layers_(module):
|
||||||
|
set_module_requires_grad_(module, False)
|
||||||
|
|
||||||
|
def unfreeze_all_layers_(module):
|
||||||
|
set_module_requires_grad_(module, True)
|
||||||
|
|
||||||
|
# diffusion prior
|
||||||
|
|
||||||
|
class DiffusionPrior(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
def forward(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
def forward(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
# main class
|
||||||
|
|
||||||
class DALLE2(nn.Module):
|
class DALLE2(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
Reference in New Issue
Block a user