make sure lowres conditioning image is properly normalized to -1 to 1 for cascading ddpm

This commit is contained in:
Phil Wang
2022-05-14 01:23:54 -07:00
parent 3115fa17b3
commit 5d27029e98
2 changed files with 13 additions and 5 deletions

View File

@@ -1,7 +1,7 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from functools import partial, wraps
from contextlib import contextmanager
from collections import namedtuple
from pathlib import Path
@@ -45,6 +45,14 @@ def exists(val):
def identity(t, *args, **kwargs):
return t
def maybe(fn):
@wraps(fn)
def inner(x):
if not exists(x):
return x
return fn(x)
return inner
def default(val, d):
if exists(val):
return val
@@ -1844,6 +1852,8 @@ class Decoder(BaseGaussianDiffusion):
b = shape[0]
img = torch.randn(shape, device = device)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(
unet,
@@ -1868,9 +1878,7 @@ class Decoder(BaseGaussianDiffusion):
# normalize to [-1, 1]
x_start = normalize_neg_one_to_one(x_start)
if exists(lowres_cond_img):
lowres_cond_img = normalize_neg_one_to_one(lowres_cond_img)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
# get x_t

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.16',
version = '0.2.17',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',