Compare commits

...

1 Commits

2 changed files with 13 additions and 5 deletions

View File

@@ -1,7 +1,7 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from functools import partial, wraps
from contextlib import contextmanager
from collections import namedtuple
from pathlib import Path
@@ -45,6 +45,14 @@ def exists(val):
def identity(t, *args, **kwargs):
return t
def maybe(fn):
@wraps(fn)
def inner(x):
if not exists(x):
return x
return fn(x)
return inner
def default(val, d):
if exists(val):
return val
@@ -1844,6 +1852,8 @@ class Decoder(BaseGaussianDiffusion):
b = shape[0]
img = torch.randn(shape, device = device)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(
unet,
@@ -1868,9 +1878,7 @@ class Decoder(BaseGaussianDiffusion):
# normalize to [-1, 1]
x_start = normalize_neg_one_to_one(x_start)
if exists(lowres_cond_img):
lowres_cond_img = normalize_neg_one_to_one(lowres_cond_img)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
# get x_t

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.16',
version = '0.2.17',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',