mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
add some training fns
This commit is contained in:
@@ -2,6 +2,15 @@ import copy
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
# image related normalizations
|
||||||
|
# ddpms expect images to be in the range of -1 to 1
|
||||||
|
|
||||||
|
def normalize_img(img):
|
||||||
|
return img * 2 - 1
|
||||||
|
|
||||||
|
def unnormalize_img(normed_img):
|
||||||
|
return (normed_img + 1) * 0.5
|
||||||
|
|
||||||
# exponential moving average wrapper
|
# exponential moving average wrapper
|
||||||
|
|
||||||
class EMA(nn.Module):
|
class EMA(nn.Module):
|
||||||
|
|||||||
Reference in New Issue
Block a user