diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py index c35bfe0..f6cebe3 100644 --- a/dalle2_pytorch/train.py +++ b/dalle2_pytorch/train.py @@ -2,6 +2,15 @@ import copy import torch from torch import nn +# image related normalizations +# ddpms expect images to be in the range of -1 to 1 + +def normalize_img(img): + return img * 2 - 1 + +def unnormalize_img(normed_img): + return (normed_img + 1) * 0.5 + # exponential moving average wrapper class EMA(nn.Module):