mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
sketch
This commit is contained in:
@@ -36,6 +36,10 @@ def freeze_all_layers_(module):
|
||||
def unfreeze_all_layers_(module):
|
||||
set_module_requires_grad_(module, True)
|
||||
|
||||
def freeze_model_and_make_eval_(model):
|
||||
model.eval()
|
||||
freeze_all_layers_(model)
|
||||
|
||||
# diffusion prior
|
||||
|
||||
class DiffusionPrior(nn.Module):
|
||||
@@ -46,14 +50,15 @@ class DiffusionPrior(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
freeze_model_and_make_eval_(clip)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
text,
|
||||
image
|
||||
image = None
|
||||
):
|
||||
return text
|
||||
return image_embed
|
||||
|
||||
# decoder
|
||||
|
||||
@@ -67,11 +72,14 @@ class Decoder(nn.Module):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
assert isinstance(prior, DiffusionPrior)
|
||||
freeze_model_and_make_eval_(clip)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
image
|
||||
image,
|
||||
image_embed,
|
||||
text_embed = None # in paper, text embedding was optional for conditioning decoder
|
||||
):
|
||||
return image
|
||||
|
||||
@@ -96,4 +104,4 @@ class DALLE2(nn.Module):
|
||||
*,
|
||||
text
|
||||
):
|
||||
return text
|
||||
return image
|
||||
|
||||
Reference in New Issue
Block a user