Compare commits


1 commit

5 changed files with 4 additions and 11 deletions

dalle2_pytorch/dalle2_pytorch.py

@@ -629,7 +629,7 @@ class NoiseScheduler(nn.Module):
     def calculate_v(self, x_start, t, noise = None):
         return (
-            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise +
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
         )
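
Note on the hunk above: in v-parameterization (Salimans & Ho, "Progressive Distillation for Fast Sampling of Diffusion Models", 2022) the regression target is v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0, so the operator joining the two terms is exactly what this one-character change flips. A minimal self-contained sketch of the computation, assuming `extract` gathers a per-timestep coefficient and reshapes it for broadcasting, as it does in this repository:

import torch

def extract(a, t, x_shape):
    # gather a[t] for each batch element, then reshape to broadcast over x
    out = a.gather(-1, t)
    return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))

def calculate_v(sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, x_start, t, noise):
    # v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_start
    return (
        extract(sqrt_alphas_cumprod, t, x_start.shape) * noise -
        extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
    )
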
@@ -1320,7 +1320,7 @@ class DiffusionPrior(nn.Module):
         elif self.predict_x_start:
             x_start = pred
         else:
-            x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
+            x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise)
         # clip x0 before maybe predicting noise
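
For reference, `predict_start_from_noise` inverts the forward marginal q(x_t | x_0): given the noised embedding and a noise estimate eps, it recovers x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t). A sketch in the same style, written in an algebraically equivalent form (the library itself may factor the coefficients differently):

def predict_start_from_noise(self, x_t, t, noise):
    # invert q(x_t | x_0): x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
    return (
        x_t - extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * noise
    ) / extract(self.sqrt_alphas_cumprod, t, x_t.shape)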

dalle2_pytorch/train_configs.py

@@ -4,13 +4,11 @@ from pydantic import BaseModel, validator, root_validator
 from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
 from x_clip import CLIP as XCLIP
-from open_clip import list_pretrained
 from coca_pytorch import CoCa
 from dalle2_pytorch.dalle2_pytorch import (
     CoCaAdapter,
     OpenAIClipAdapter,
-    OpenClipAdapter,
     Unet,
     Decoder,
     DiffusionPrior,
@@ -119,10 +117,6 @@ class AdapterConfig(BaseModel):
     def create(self):
         if self.make == "openai":
             return OpenAIClipAdapter(self.model)
-        elif self.make == "open_clip":
-            pretrained = dict(list_pretrained())
-            checkpoint = pretrained[self.model]
-            return OpenClipAdapter(name=self.model, pretrained=checkpoint)
         elif self.make == "x-clip":
             return XClipAdapter(XCLIP(**self.base_model_kwargs))
         elif self.make == "coca":
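
For context on the removed "open_clip" branch: `open_clip.list_pretrained()` returns (model_name, pretrained_tag) pairs, so `dict(list_pretrained())` builds a lookup from architecture name to checkpoint tag (a plain dict keeps only the last tag listed per architecture). A short sketch of that lookup; the model name and tag below are illustrative, not taken from this diff:

from open_clip import list_pretrained

pretrained = dict(list_pretrained())  # e.g. {'ViT-B-32': 'laion2b_s34b_b79k', ...}
checkpoint = pretrained['ViT-B-32']   # illustrative model name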

dalle2_pytorch/trainer.py

@@ -236,7 +236,7 @@ class DiffusionPriorTrainer(nn.Module):
         )
         if exists(cosine_decay_max_steps):
-            self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
+            self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
         else:
             self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
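
Design note on the hunk above: constructing a scheduler on both branches lets the training loop call `self.scheduler.step()` unconditionally, since `LambdaLR` with a constant multiplier of 1.0 is a no-op schedule. A minimal sketch with a hypothetical one-parameter model (only `cosine_decay_max_steps` comes from the trainer):

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

optimizer = Adam([torch.nn.Parameter(torch.zeros(1))], lr = 1e-4)  # hypothetical optimizer

cosine_decay_max_steps = None  # decay only when a horizon is configured
if cosine_decay_max_steps is not None:
    scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
else:
    scheduler = LambdaLR(optimizer, lr_lambda = lambda _: 1.0)  # constant learning rate

for _ in range(3):
    optimizer.step()
    scheduler.step()  # safe to call on either branch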

dalle2_pytorch/version.py

@@ -1 +1 @@
-__version__ = '1.11.2'
+__version__ = '1.11.0'

setup.py

@@ -26,7 +26,6 @@ setup(
   install_requires=[
     'accelerate',
     'click',
-    'open-clip-torch>=2.0.0,<3.0.0',
     'clip-anytorch>=2.4.0',
     'coca-pytorch>=0.0.5',
     'ema-pytorch>=0.0.7',