add some training fns

This commit is contained in:
Phil Wang
2022-04-29 09:44:55 -07:00
parent fb662a62f3
commit f4a54e475e

View File

@@ -2,6 +2,15 @@ import copy
import torch
from torch import nn
# image related normalizations
# ddpms expect images to be in the range of -1 to 1
def normalize_img(img):
return img * 2 - 1
def unnormalize_img(normed_img):
return (normed_img + 1) * 0.5
# exponential moving average wrapper
class EMA(nn.Module):