make sure lowres conditioning image is properly normalized to -1 to 1 for cascading ddpm

This commit is contained in:
Phil Wang
2022-05-14 01:23:54 -07:00
parent 3115fa17b3
commit 5d27029e98
2 changed files with 13 additions and 5 deletions

View File

@@ -1,7 +1,7 @@
import math import math
from tqdm import tqdm from tqdm import tqdm
from inspect import isfunction from inspect import isfunction
from functools import partial from functools import partial, wraps
from contextlib import contextmanager from contextlib import contextmanager
from collections import namedtuple from collections import namedtuple
from pathlib import Path from pathlib import Path
@@ -45,6 +45,14 @@ def exists(val):
def identity(t, *args, **kwargs): def identity(t, *args, **kwargs):
return t return t
def maybe(fn):
@wraps(fn)
def inner(x):
if not exists(x):
return x
return fn(x)
return inner
def default(val, d): def default(val, d):
if exists(val): if exists(val):
return val return val
@@ -1844,6 +1852,8 @@ class Decoder(BaseGaussianDiffusion):
b = shape[0] b = shape[0]
img = torch.randn(shape, device = device) img = torch.randn(shape, device = device)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample( img = self.p_sample(
unet, unet,
@@ -1868,9 +1878,7 @@ class Decoder(BaseGaussianDiffusion):
# normalize to [-1, 1] # normalize to [-1, 1]
x_start = normalize_neg_one_to_one(x_start) x_start = normalize_neg_one_to_one(x_start)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
if exists(lowres_cond_img):
lowres_cond_img = normalize_neg_one_to_one(lowres_cond_img)
# get x_t # get x_t

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.2.16', version = '0.2.17',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',