This commit is contained in:
Phil Wang
2022-04-12 09:39:42 -07:00
parent 7cf1637d24
commit 62c0d321a6

View File

@@ -36,6 +36,10 @@ def freeze_all_layers_(module):
def unfreeze_all_layers_(module):
set_module_requires_grad_(module, True)
def freeze_model_and_make_eval_(model):
model.eval()
freeze_all_layers_(model)
# diffusion prior
class DiffusionPrior(nn.Module):
@@ -46,14 +50,15 @@ class DiffusionPrior(nn.Module):
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
def forward(
self,
*,
text,
image
image = None
):
return text
return image_embed
# decoder
@@ -67,11 +72,14 @@ class Decoder(nn.Module):
super().__init__()
assert isinstance(clip, CLIP)
assert isinstance(prior, DiffusionPrior)
freeze_model_and_make_eval_(clip)
def forward(
self,
*,
image
image,
image_embed,
text_embed = None # in paper, text embedding was optional for conditioning decoder
):
return image
@@ -96,4 +104,4 @@ class DALLE2(nn.Module):
*,
text
):
return text
return image