mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
make sure lowres conditioning image is properly normalized to -1 to 1 for cascading ddpm
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from inspect import isfunction
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
@@ -45,6 +45,14 @@ def exists(val):
|
||||
def identity(t, *args, **kwargs):
|
||||
return t
|
||||
|
||||
def maybe(fn):
|
||||
@wraps(fn)
|
||||
def inner(x):
|
||||
if not exists(x):
|
||||
return x
|
||||
return fn(x)
|
||||
return inner
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
@@ -1844,6 +1852,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
img = self.p_sample(
|
||||
unet,
|
||||
@@ -1868,9 +1878,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
# normalize to [-1, 1]
|
||||
|
||||
x_start = normalize_neg_one_to_one(x_start)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = normalize_neg_one_to_one(lowres_cond_img)
|
||||
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
|
||||
|
||||
# get x_t
|
||||
|
||||
|
||||
Reference in New Issue
Block a user