diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 0b5ade0..67ca8fc 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -36,6 +36,10 @@ def freeze_all_layers_(module):
 def unfreeze_all_layers_(module):
     set_module_requires_grad_(module, True)
 
+def freeze_model_and_make_eval_(model):
+    model.eval()
+    freeze_all_layers_(model)
+
 # diffusion prior
 
 class DiffusionPrior(nn.Module):
@@ -46,14 +50,15 @@ class DiffusionPrior(nn.Module):
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
+        freeze_model_and_make_eval_(clip)
 
     def forward(
         self,
         *,
         text,
-        image
+        image = None
     ):
-        return text
+        return image_embed
 
 # decoder
 
@@ -67,11 +72,14 @@ class Decoder(nn.Module):
         super().__init__()
         assert isinstance(clip, CLIP)
         assert isinstance(prior, DiffusionPrior)
+        freeze_model_and_make_eval_(clip)
 
     def forward(
         self,
         *,
-        image
+        image,
+        image_embed,
+        text_embed = None # in paper, text embedding was optional for conditioning decoder
     ):
         return image
 
@@ -96,4 +104,4 @@ class DALLE2(nn.Module):
         *,
         text
     ):
-        return text
+        return image