# mirror of https://github.com/lucidrains/DALLE2-pytorch.git
# synced 2025-12-19 09:44:19 +01:00
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange

# use x-clip

from x_clip import CLIP
# helper functions
def exists(val):
    """Return ``True`` when *val* is anything other than ``None``."""
    return val is not None
def default(val, d):
    """Return *val* unless it is ``None``, in which case fall back to *d*."""
    return d if val is None else val
def eval_decorator(fn):
    """Wrap a method so the model is temporarily switched to eval mode.

    The wrapped *fn* is called with the model in ``eval()`` mode; the
    model's original ``training`` flag is restored afterwards.

    Fixes over the original:
    - ``functools.wraps`` preserves the decorated function's metadata
      (``__name__``, docstring), which the original dropped.
    - ``try/finally`` restores the training mode even when *fn* raises,
      so an exception no longer leaves the model stuck in eval mode.
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            out = fn(model, *args, **kwargs)
        finally:
            # restore whatever mode the model was in before the call
            model.train(was_training)
        return out
    return inner
# for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad):
    """Set the ``requires_grad`` flag on every parameter of *module* in place."""
    for parameter in module.parameters():
        parameter.requires_grad = requires_grad
def freeze_all_layers_(module):
    """Disable gradient tracking for every parameter of *module* (in place)."""
    set_module_requires_grad_(module, requires_grad = False)
def unfreeze_all_layers_(module):
    """Re-enable gradient tracking for every parameter of *module* (in place)."""
    set_module_requires_grad_(module, requires_grad = True)
# diffusion prior
class DiffusionPrior(nn.Module):
    """Placeholder diffusion prior network.

    Validates that a CLIP instance is supplied; ``forward`` is currently a
    stub that echoes the text input back unchanged.
    """

    def __init__(self, *, clip):
        super().__init__()
        assert isinstance(clip, CLIP)

    def forward(self, *, text, image):
        # stub implementation — passes the text input straight through
        return text
# decoder
class Decoder(nn.Module):
    """Placeholder decoder network.

    Validates the supplied CLIP and DiffusionPrior instances; ``forward``
    is currently a stub that echoes the image input back unchanged.
    """

    def __init__(self, *, clip, prior):
        super().__init__()
        assert isinstance(clip, CLIP)
        assert isinstance(prior, DiffusionPrior)

    def forward(self, *, image):
        # stub implementation — passes the image input straight through
        return image
# main class
class DALLE2(nn.Module):
    """Top-level DALL-E 2 wrapper combining CLIP, a diffusion prior and a decoder.

    ``forward`` is currently a stub (runs under ``torch.no_grad`` and echoes
    the text input back unchanged).
    """

    def __init__(self, *, clip, prior, decoder):
        super().__init__()
        # BUG FIX: the original wrote `assert isinstance(clip), CLIP`, which
        # parses as `assert isinstance(clip)` with CLIP as the assert message —
        # isinstance() with one argument raises TypeError at runtime. Corrected
        # to the two-argument form, matching DiffusionPrior and Decoder above.
        assert isinstance(clip, CLIP)
        assert isinstance(prior, DiffusionPrior)
        assert isinstance(decoder, Decoder)

    @torch.no_grad()
    def forward(self, *, text):
        # stub implementation — passes the text input straight through
        return text