Compare commits

...

3 Commits
0.0.1 ... 0.0.2

Author SHA1 Message Date
Phil Wang
f5e0aea140 get ready for CLI tool, just like stylegan2_pytorch 2022-04-12 09:57:54 -07:00
Phil Wang
5e03b7f932 get ready for all the training related classes and functions 2022-04-12 09:54:50 -07:00
Phil Wang
62c0d321a6 sketch 2022-04-12 09:39:42 -07:00
3 changed files with 22 additions and 7 deletions

View File

@@ -36,6 +36,10 @@ def freeze_all_layers_(module):
def unfreeze_all_layers_(module): def unfreeze_all_layers_(module):
set_module_requires_grad_(module, True) set_module_requires_grad_(module, True)
def freeze_model_and_make_eval_(model):
model.eval()
freeze_all_layers_(model)
# diffusion prior # diffusion prior
class DiffusionPrior(nn.Module): class DiffusionPrior(nn.Module):
@@ -46,14 +50,15 @@ class DiffusionPrior(nn.Module):
): ):
super().__init__() super().__init__()
assert isinstance(clip, CLIP) assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
def forward( def forward(
self, self,
*, *,
text, text,
image image = None
): ):
return text return image_embed
# decoder # decoder
@@ -67,11 +72,14 @@ class Decoder(nn.Module):
super().__init__() super().__init__()
assert isinstance(clip, CLIP) assert isinstance(clip, CLIP)
assert isinstance(prior, DiffusionPrior) assert isinstance(prior, DiffusionPrior)
freeze_model_and_make_eval_(clip)
def forward( def forward(
self, self,
*, *,
image image,
image_embed,
text_embed = None # in paper, text embedding was optional for conditioning decoder
): ):
return image return image
@@ -96,4 +104,4 @@ class DALLE2(nn.Module):
*, *,
text text
): ):
return text return image

0
dalle2_pytorch/train.py Normal file
View File

View File

@@ -4,7 +4,12 @@ setup(
name = 'dalle2-pytorch', name = 'dalle2-pytorch',
packages = find_packages(exclude=[]), packages = find_packages(exclude=[]),
include_package_data = True, include_package_data = True,
version = '0.0.1', entry_points={
'console_scripts': [
'dalle2_pytorch = dalle2_pytorch.cli:main',
],
},
version = '0.0.2',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',
@@ -16,11 +21,13 @@ setup(
'text to image' 'text to image'
], ],
install_requires=[ install_requires=[
'click',
'einops>=0.4', 'einops>=0.4',
'einops-exts', 'einops-exts',
'torch>=1.6', 'torch>=1.10',
'torchvision',
'x-clip>=0.4.1', 'x-clip>=0.4.1',
'yttm' 'youtokentome'
], ],
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', 'Development Status :: 4 - Beta',