mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-21 07:14:21 +01:00
soon is now
This commit is contained in:
24
sgm/modules/diffusionmodules/denoiser_weighting.py
Normal file
24
sgm/modules/diffusionmodules/denoiser_weighting.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
|
||||
class UnitWeighting:
|
||||
def __call__(self, sigma):
|
||||
return torch.ones_like(sigma, device=sigma.device)
|
||||
|
||||
|
||||
class EDMWeighting:
|
||||
def __init__(self, sigma_data=0.5):
|
||||
self.sigma_data = sigma_data
|
||||
|
||||
def __call__(self, sigma):
|
||||
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
|
||||
|
||||
|
||||
class VWeighting(EDMWeighting):
|
||||
def __init__(self):
|
||||
super().__init__(sigma_data=1.0)
|
||||
|
||||
|
||||
class EpsWeighting:
|
||||
def __call__(self, sigma):
|
||||
return sigma**-2.0
|
||||
Reference in New Issue
Block a user