mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 20:44:31 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f5e0aea140 | ||
|
|
5e03b7f932 | ||
|
|
62c0d321a6 |
@@ -36,6 +36,10 @@ def freeze_all_layers_(module):
|
|||||||
def unfreeze_all_layers_(module):
|
def unfreeze_all_layers_(module):
|
||||||
set_module_requires_grad_(module, True)
|
set_module_requires_grad_(module, True)
|
||||||
|
|
||||||
|
def freeze_model_and_make_eval_(model):
|
||||||
|
model.eval()
|
||||||
|
freeze_all_layers_(model)
|
||||||
|
|
||||||
# diffusion prior
|
# diffusion prior
|
||||||
|
|
||||||
class DiffusionPrior(nn.Module):
|
class DiffusionPrior(nn.Module):
|
||||||
@@ -46,14 +50,15 @@ class DiffusionPrior(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(clip, CLIP)
|
assert isinstance(clip, CLIP)
|
||||||
|
freeze_model_and_make_eval_(clip)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
text,
|
text,
|
||||||
image
|
image = None
|
||||||
):
|
):
|
||||||
return text
|
return image_embed
|
||||||
|
|
||||||
# decoder
|
# decoder
|
||||||
|
|
||||||
@@ -67,11 +72,14 @@ class Decoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(clip, CLIP)
|
assert isinstance(clip, CLIP)
|
||||||
assert isinstance(prior, DiffusionPrior)
|
assert isinstance(prior, DiffusionPrior)
|
||||||
|
freeze_model_and_make_eval_(clip)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
image
|
image,
|
||||||
|
image_embed,
|
||||||
|
text_embed = None # in paper, text embedding was optional for conditioning decoder
|
||||||
):
|
):
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@@ -96,4 +104,4 @@ class DALLE2(nn.Module):
|
|||||||
*,
|
*,
|
||||||
text
|
text
|
||||||
):
|
):
|
||||||
return text
|
return image
|
||||||
|
|||||||
0
dalle2_pytorch/train.py
Normal file
0
dalle2_pytorch/train.py
Normal file
13
setup.py
13
setup.py
@@ -4,7 +4,12 @@ setup(
|
|||||||
name = 'dalle2-pytorch',
|
name = 'dalle2-pytorch',
|
||||||
packages = find_packages(exclude=[]),
|
packages = find_packages(exclude=[]),
|
||||||
include_package_data = True,
|
include_package_data = True,
|
||||||
version = '0.0.1',
|
entry_points={
|
||||||
|
'console_scripts': [
|
||||||
|
'dalle2_pytorch = dalle2_pytorch.cli:main',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
version = '0.0.2',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
@@ -16,11 +21,13 @@ setup(
|
|||||||
'text to image'
|
'text to image'
|
||||||
],
|
],
|
||||||
install_requires=[
|
install_requires=[
|
||||||
|
'click',
|
||||||
'einops>=0.4',
|
'einops>=0.4',
|
||||||
'einops-exts',
|
'einops-exts',
|
||||||
'torch>=1.6',
|
'torch>=1.10',
|
||||||
|
'torchvision',
|
||||||
'x-clip>=0.4.1',
|
'x-clip>=0.4.1',
|
||||||
'yttm'
|
'youtokentome'
|
||||||
],
|
],
|
||||||
classifiers=[
|
classifiers=[
|
||||||
'Development Status :: 4 - Beta',
|
'Development Status :: 4 - Beta',
|
||||||
|
|||||||
Reference in New Issue
Block a user