mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
fix version
This commit is contained in:
@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
from dalle2_pytorch.version import __version__
|
||||
from packaging import version
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -57,8 +59,7 @@ def num_to_groups(num, divisor):
|
||||
return arr
|
||||
|
||||
def get_pkg_version():
|
||||
from pkg_resources import get_distribution
|
||||
return get_distribution('dalle2_pytorch').version
|
||||
return __version__
|
||||
|
||||
# decorators
|
||||
|
||||
@@ -299,7 +300,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
scaler = self.scaler.state_dict(),
|
||||
optimizer = self.optimizer.state_dict(),
|
||||
model = self.diffusion_prior.state_dict(),
|
||||
version = get_pkg_version(),
|
||||
version = __version__,
|
||||
step = self.step.item(),
|
||||
**kwargs
|
||||
)
|
||||
@@ -315,8 +316,8 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
loaded_obj = torch.load(str(path))
|
||||
|
||||
if get_pkg_version() != loaded_obj['version']:
|
||||
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
|
||||
if version.parse(__version__) != loaded_obj['version']:
|
||||
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
|
||||
|
||||
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
@@ -463,7 +464,7 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
save_obj = dict(
|
||||
model = self.decoder.state_dict(),
|
||||
version = get_pkg_version(),
|
||||
version = __version__,
|
||||
step = self.step.item(),
|
||||
**kwargs
|
||||
)
|
||||
@@ -486,7 +487,7 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
loaded_obj = torch.load(str(path))
|
||||
|
||||
if get_pkg_version() != loaded_obj['version']:
|
||||
if version.parse(__version__) != loaded_obj['version']:
|
||||
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
|
||||
|
||||
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
|
||||
|
||||
Reference in New Issue
Block a user