Compare commits

...

3 Commits
0.0.1 ... 0.0.2

Author SHA1 Message Date
Phil Wang
f5e0aea140 get ready for CLI tool, just like stylegan2_pytorch 2022-04-12 09:57:54 -07:00
Phil Wang
5e03b7f932 get ready for all the training related classes and functions 2022-04-12 09:54:50 -07:00
Phil Wang
62c0d321a6 sketch 2022-04-12 09:39:42 -07:00
3 changed files with 22 additions and 7 deletions

View File

@@ -36,6 +36,10 @@ def freeze_all_layers_(module):
def unfreeze_all_layers_(module):
set_module_requires_grad_(module, True)
def freeze_model_and_make_eval_(model):
model.eval()
freeze_all_layers_(model)
# diffusion prior
class DiffusionPrior(nn.Module):
@@ -46,14 +50,15 @@ class DiffusionPrior(nn.Module):
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
def forward(
self,
*,
text,
image
image = None
):
return text
return image_embed
# decoder
@@ -67,11 +72,14 @@ class Decoder(nn.Module):
super().__init__()
assert isinstance(clip, CLIP)
assert isinstance(prior, DiffusionPrior)
freeze_model_and_make_eval_(clip)
def forward(
self,
*,
image
image,
image_embed,
text_embed = None # in paper, text embedding was optional for conditioning decoder
):
return image
@@ -96,4 +104,4 @@ class DALLE2(nn.Module):
*,
text
):
return text
return image

0
dalle2_pytorch/train.py Normal file
View File

View File

@@ -4,7 +4,12 @@ setup(
name = 'dalle2-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.0.1',
entry_points={
'console_scripts': [
'dalle2_pytorch = dalle2_pytorch.cli:main',
],
},
version = '0.0.2',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -16,11 +21,13 @@ setup(
'text to image'
],
install_requires=[
'click',
'einops>=0.4',
'einops-exts',
'torch>=1.6',
'torch>=1.10',
'torchvision',
'x-clip>=0.4.1',
'yttm'
'youtokentome'
],
classifiers=[
'Development Status :: 4 - Beta',