mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 20:54:22 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b588286288 |
@@ -1,3 +1,4 @@
|
|||||||
|
from dalle2_pytorch.version import __version__
|
||||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||||
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
|
|||||||
|
|
||||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||||
from dalle2_pytorch.optimizer import get_optimizer
|
from dalle2_pytorch.optimizer import get_optimizer
|
||||||
|
from dalle2_pytorch.version import __version__
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -57,8 +59,7 @@ def num_to_groups(num, divisor):
|
|||||||
return arr
|
return arr
|
||||||
|
|
||||||
def get_pkg_version():
|
def get_pkg_version():
|
||||||
from pkg_resources import get_distribution
|
return __version__
|
||||||
return get_distribution('dalle2_pytorch').version
|
|
||||||
|
|
||||||
# decorators
|
# decorators
|
||||||
|
|
||||||
@@ -299,7 +300,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
scaler = self.scaler.state_dict(),
|
scaler = self.scaler.state_dict(),
|
||||||
optimizer = self.optimizer.state_dict(),
|
optimizer = self.optimizer.state_dict(),
|
||||||
model = self.diffusion_prior.state_dict(),
|
model = self.diffusion_prior.state_dict(),
|
||||||
version = get_pkg_version(),
|
version = __version__,
|
||||||
step = self.step.item(),
|
step = self.step.item(),
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
@@ -315,8 +316,8 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
loaded_obj = torch.load(str(path))
|
loaded_obj = torch.load(str(path))
|
||||||
|
|
||||||
if get_pkg_version() != loaded_obj['version']:
|
if version.parse(__version__) != loaded_obj['version']:
|
||||||
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
|
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
|
||||||
|
|
||||||
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
|
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
|
||||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||||
@@ -463,7 +464,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
save_obj = dict(
|
save_obj = dict(
|
||||||
model = self.decoder.state_dict(),
|
model = self.decoder.state_dict(),
|
||||||
version = get_pkg_version(),
|
version = __version__,
|
||||||
step = self.step.item(),
|
step = self.step.item(),
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
@@ -486,7 +487,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
loaded_obj = torch.load(str(path))
|
loaded_obj = torch.load(str(path))
|
||||||
|
|
||||||
if get_pkg_version() != loaded_obj['version']:
|
if version.parse(__version__) != loaded_obj['version']:
|
||||||
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
|
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
|
||||||
|
|
||||||
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
|
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
|
||||||
|
|||||||
1
dalle2_pytorch/version.py
Normal file
1
dalle2_pytorch/version.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__version__ = '0.6.2'
|
||||||
3
setup.py
3
setup.py
@@ -1,4 +1,5 @@
|
|||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
|
exec(open('dalle2_pytorch/version.py').read())
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name = 'dalle2-pytorch',
|
name = 'dalle2-pytorch',
|
||||||
@@ -10,7 +11,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.6.1',
|
version = __version__,
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
|
|||||||
Reference in New Issue
Block a user