Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-15 00:34:19 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | 9faab59b23 |  |
|  | 5d27029e98 |  |
@@ -1,7 +1,7 @@
 import math
 from tqdm import tqdm
 from inspect import isfunction
-from functools import partial
+from functools import partial, wraps
 from contextlib import contextmanager
 from collections import namedtuple
 from pathlib import Path
@@ -45,6 +45,14 @@ def exists(val):
 def identity(t, *args, **kwargs):
     return t
 
+def maybe(fn):
+    @wraps(fn)
+    def inner(x):
+        if not exists(x):
+            return x
+        return fn(x)
+    return inner
+
 def default(val, d):
     if exists(val):
         return val
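For reference, the `maybe` helper added in the hunk above turns any one-argument function into a version that lets `None` pass through untouched. Below is a minimal, self-contained sketch: `exists` and `normalize_neg_one_to_one` are re-declared here only so the snippet runs on its own, mirroring the repo's one-line helpers.

```python
from functools import wraps

def exists(val):
    # the repo's convention: None means "not provided"
    return val is not None

def maybe(fn):
    # wrap fn so a None input is returned as-is instead of being passed to fn
    @wraps(fn)
    def inner(x):
        if not exists(x):
            return x
        return fn(x)
    return inner

def normalize_neg_one_to_one(img):
    # maps values in [0, 1] to [-1, 1]
    return img * 2 - 1

safe_normalize = maybe(normalize_neg_one_to_one)
print(safe_normalize(None))   # None passes straight through
print(safe_normalize(0.75))   # 0.5
```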
@@ -1173,7 +1181,11 @@ class CrossAttention(nn.Module):
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            LayerNorm(dim)
+        )
 
     def forward(self, x, context, mask = None):
         b, n, device = *x.shape[:2], x.device
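The CrossAttention hunk above swaps the bare output projection for a projection followed by a layer norm, which normalizes the attention output's statistics without changing its shape. A rough before/after sketch, with hypothetical dimensions and `torch.nn.LayerNorm` standing in for the repo's own `LayerNorm` class:

```python
import torch
from torch import nn

dim = inner_dim = 512  # hypothetical sizes, only to make the shapes concrete

# before: a bare output projection
to_out_old = nn.Linear(inner_dim, dim, bias = False)

# after: projection followed by a layer norm (nn.LayerNorm stands in for the
# repo's custom LayerNorm here)
to_out_new = nn.Sequential(
    nn.Linear(inner_dim, dim, bias = False),
    nn.LayerNorm(dim)
)

x = torch.randn(2, 77, inner_dim)                  # (batch, tokens, inner_dim)
assert to_out_old(x).shape == to_out_new(x).shape  # shape is unchanged; only the output is normalized
```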
@@ -1844,6 +1856,8 @@ class Decoder(BaseGaussianDiffusion):
         b = shape[0]
         img = torch.randn(shape, device = device)
 
+        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
             img = self.p_sample(
                 unet,
@@ -1868,9 +1882,7 @@ class Decoder(BaseGaussianDiffusion):
         # normalize to [-1, 1]
 
         x_start = normalize_neg_one_to_one(x_start)
-
-        if exists(lowres_cond_img):
-            lowres_cond_img = normalize_neg_one_to_one(lowres_cond_img)
+        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
 
         # get x_t
 
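The two Decoder hunks apply the new helper: wrapping `normalize_neg_one_to_one` in `maybe` collapses the old `if exists(...)` branch into a single line that lets a `None` low-resolution conditioning image pass through untouched. A small equivalence check with hypothetical values, the helpers re-declared so the snippet stands alone:

```python
def exists(val):
    return val is not None

def maybe(fn):
    def inner(x):
        return fn(x) if exists(x) else x
    return inner

def normalize_neg_one_to_one(img):
    return img * 2 - 1

def old_style(lowres_cond_img):
    # the pattern removed by this commit
    if exists(lowres_cond_img):
        lowres_cond_img = normalize_neg_one_to_one(lowres_cond_img)
    return lowres_cond_img

def new_style(lowres_cond_img):
    # the pattern introduced by this commit: None simply passes through
    return maybe(normalize_neg_one_to_one)(lowres_cond_img)

for value in (None, 0.25):
    assert old_style(value) == new_style(value)
```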