mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-15 09:04:25 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bdc3b222f2 |
@@ -373,7 +373,7 @@ def quadratic_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
|
||||
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
|
||||
|
||||
|
||||
def sigmoid_beta_schedule(timesteps):
|
||||
|
||||
@@ -238,7 +238,7 @@ class EMA(nn.Module):
|
||||
ma_buffer.data.copy_(current_buffer.data)
|
||||
|
||||
def get_current_decay(self):
|
||||
epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
|
||||
epoch = clamp(self.step.item() - self.update_after_step - 1, min = 0)
|
||||
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
||||
|
||||
if epoch <= 0:
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.6.16'
|
||||
__version__ = '0.6.14'
|
||||
|
||||
Reference in New Issue
Block a user