mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 02:04:19 +01:00
sketch
This commit is contained in:
@@ -36,6 +36,10 @@ def freeze_all_layers_(module):
|
|||||||
def unfreeze_all_layers_(module):
|
def unfreeze_all_layers_(module):
|
||||||
set_module_requires_grad_(module, True)
|
set_module_requires_grad_(module, True)
|
||||||
|
|
||||||
|
def freeze_model_and_make_eval_(model):
|
||||||
|
model.eval()
|
||||||
|
freeze_all_layers_(model)
|
||||||
|
|
||||||
# diffusion prior
|
# diffusion prior
|
||||||
|
|
||||||
class DiffusionPrior(nn.Module):
|
class DiffusionPrior(nn.Module):
|
||||||
@@ -46,14 +50,15 @@ class DiffusionPrior(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(clip, CLIP)
|
assert isinstance(clip, CLIP)
|
||||||
|
freeze_model_and_make_eval_(clip)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
text,
|
text,
|
||||||
image
|
image = None
|
||||||
):
|
):
|
||||||
return text
|
return image_embed
|
||||||
|
|
||||||
# decoder
|
# decoder
|
||||||
|
|
||||||
@@ -67,11 +72,14 @@ class Decoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(clip, CLIP)
|
assert isinstance(clip, CLIP)
|
||||||
assert isinstance(prior, DiffusionPrior)
|
assert isinstance(prior, DiffusionPrior)
|
||||||
|
freeze_model_and_make_eval_(clip)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
image
|
image,
|
||||||
|
image_embed,
|
||||||
|
text_embed = None # in paper, text embedding was optional for conditioning decoder
|
||||||
):
|
):
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@@ -96,4 +104,4 @@ class DALLE2(nn.Module):
|
|||||||
*,
|
*,
|
||||||
text
|
text
|
||||||
):
|
):
|
||||||
return text
|
return image
|
||||||
|
|||||||
Reference in New Issue
Block a user