Compare commits

...

2 Commits

Author SHA1 Message Date
Phil Wang
b588286288 fix version 2022-05-30 11:06:34 -07:00
Phil Wang
b693e0be03 default number of resnet blocks per layer in unet to 2 (in imagen it was 3 for base 64x64) 2022-05-30 10:06:48 -07:00
5 changed files with 13 additions and 9 deletions

View File

@@ -1,3 +1,4 @@
from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer

View File

@@ -1347,7 +1347,7 @@ class Unet(nn.Module):
init_dim = None,
init_conv_kernel_size = 7,
resnet_groups = 8,
num_resnet_blocks = 1,
num_resnet_blocks = 2,
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),

View File

@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
import numpy as np
@@ -57,8 +59,7 @@ def num_to_groups(num, divisor):
return arr
def get_pkg_version():
from pkg_resources import get_distribution
return get_distribution('dalle2_pytorch').version
return __version__
# decorators
@@ -299,7 +300,7 @@ class DiffusionPriorTrainer(nn.Module):
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(),
version = get_pkg_version(),
version = __version__,
step = self.step.item(),
**kwargs
)
@@ -315,8 +316,8 @@ class DiffusionPriorTrainer(nn.Module):
loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
@@ -463,7 +464,7 @@ class DecoderTrainer(nn.Module):
save_obj = dict(
model = self.decoder.state_dict(),
version = get_pkg_version(),
version = __version__,
step = self.step.item(),
**kwargs
)
@@ -486,7 +487,7 @@ class DecoderTrainer(nn.Module):
loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']:
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)

View File

@@ -0,0 +1 @@
__version__ = '0.6.2'

View File

@@ -1,4 +1,5 @@
from setuptools import setup, find_packages
exec(open('dalle2_pytorch/version.py').read())
setup(
name = 'dalle2-pytorch',
@@ -10,7 +11,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.6.0',
version = __version__,
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',