From f4a54e475ee2b25db62e5d21a43a5292d9a6bf5e Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 29 Apr 2022 09:44:55 -0700 Subject: [PATCH] add some training fns --- dalle2_pytorch/train.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py index c35bfe0..f6cebe3 100644 --- a/dalle2_pytorch/train.py +++ b/dalle2_pytorch/train.py @@ -2,6 +2,15 @@ import copy import torch from torch import nn +# image related normalizations +# ddpms expect images to be in the range of -1 to 1 + +def normalize_img(img): + return img * 2 - 1 + +def unnormalize_img(normed_img): + return (normed_img + 1) * 0.5 + # exponential moving average wrapper class EMA(nn.Module):