From b588286288b8b0bb3e3a7132ddc3a3643629e456 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 30 May 2022 11:06:34 -0700 Subject: [PATCH] fix version --- dalle2_pytorch/__init__.py | 1 + dalle2_pytorch/trainer.py | 15 ++++++++------- dalle2_pytorch/version.py | 1 + setup.py | 3 ++- 4 files changed, 12 insertions(+), 8 deletions(-) create mode 100644 dalle2_pytorch/version.py diff --git a/dalle2_pytorch/__init__.py b/dalle2_pytorch/__init__.py index 7394cc1..f1d88d6 100644 --- a/dalle2_pytorch/__init__.py +++ b/dalle2_pytorch/__init__.py @@ -1,3 +1,4 @@ +from dalle2_pytorch.version import __version__ from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index 4d1ab07..6c5d1f5 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior from dalle2_pytorch.optimizer import get_optimizer +from dalle2_pytorch.version import __version__ +from packaging import version import numpy as np @@ -57,8 +59,7 @@ def num_to_groups(num, divisor): return arr def get_pkg_version(): - from pkg_resources import get_distribution - return get_distribution('dalle2_pytorch').version + return __version__ # decorators @@ -299,7 +300,7 @@ class DiffusionPriorTrainer(nn.Module): scaler = self.scaler.state_dict(), optimizer = self.optimizer.state_dict(), model = self.diffusion_prior.state_dict(), - version = get_pkg_version(), + version = __version__, step = self.step.item(), **kwargs ) @@ -315,8 +316,8 @@ class DiffusionPriorTrainer(nn.Module): loaded_obj = torch.load(str(path)) - if get_pkg_version() != loaded_obj['version']: - print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}') + if version.parse(__version__) != loaded_obj['version']: + print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}') self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict) self.step.copy_(torch.ones_like(self.step) * loaded_obj['step']) @@ -463,7 +464,7 @@ class DecoderTrainer(nn.Module): save_obj = dict( model = self.decoder.state_dict(), - version = get_pkg_version(), + version = __version__, step = self.step.item(), **kwargs ) @@ -486,7 +487,7 @@ class DecoderTrainer(nn.Module): loaded_obj = torch.load(str(path)) - if get_pkg_version() != loaded_obj['version']: + if version.parse(__version__) != loaded_obj['version']: print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}') self.decoder.load_state_dict(loaded_obj['model'], strict = strict) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py new file mode 100644 index 0000000..aece342 --- /dev/null +++ b/dalle2_pytorch/version.py @@ -0,0 +1 @@ +__version__ = '0.6.2' diff --git a/setup.py b/setup.py index fc32ff8..f83eb09 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ from setuptools import setup, find_packages +exec(open('dalle2_pytorch/version.py').read()) setup( name = 'dalle2-pytorch', @@ -10,7 +11,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.6.1', + version = __version__, license='MIT', description = 'DALL-E 2', author = 'Phil Wang',