Compare commits

1 commit

Author      SHA1        Message                                                          Date
Phil Wang   dc50c6b34e  allow for config driven creation of clip-less diffusion prior   2022-05-22 20:13:20 -07:00
4 changed files with 9 additions and 15 deletions

dalle2_pytorch/dalle2_pytorch.py

@@ -1712,7 +1712,7 @@ class Decoder(BaseGaussianDiffusion):
         self.unconditional = unconditional
 
         assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
-        assert self.unconditional or (exists(clip) or exists(image_size) or exists(image_sizes)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
+        assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
 
         self.clip = None
         if exists(clip):
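The only change in this hunk is `or` versus `^`. Since `exists` in this codebase is a plain None check, the inclusive `or` accepts CLIP and explicit image sizes together, while the exclusive `^` (XOR) insists on exactly one of the two. A minimal standalone sketch of the difference (the variable values here are hypothetical):

    def exists(val):
        return val is not None

    clip, image_size = object(), 256    # hypothetical: both supplied at once

    exists(clip) or exists(image_size)  # True  -- inclusive or: at least one given
    exists(clip) ^ exists(image_size)   # False -- xor: rejects supplying both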

dalle2_pytorch/train_configs.py

@@ -68,7 +68,7 @@ class UnetConfig(BaseModel):
         extra = "allow"
 
 class DecoderConfig(BaseModel):
-    unets: ListOrTuple(UnetConfig)
+    unets: Union[List[UnetConfig], Tuple[UnetConfig]]
     image_size: int = None
     image_sizes: ListOrTuple(int) = None
     channels: int = 3
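`ListOrTuple` on the removed side is presumably a small typing helper defined earlier in this file (it also annotates `image_sizes` in the unchanged context lines). For the two `unets` annotations to be equivalent, it would have to read roughly as follows; a sketch, not the file's confirmed definition:

    from typing import List, Tuple, Union

    def ListOrTuple(inner_type):
        # accept either a list or a tuple of the inner type
        return Union[List[inner_type], Tuple[inner_type]]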

dalle2_pytorch/trainer.py

@@ -288,7 +288,7 @@ class DiffusionPriorTrainer(nn.Module):
 
         self.register_buffer('step', torch.tensor([0]))
 
-    def save(self, path, overwrite = True, **kwargs):
+    def save(self, path, overwrite = True):
         path = Path(path)
         assert not (path.exists() and not overwrite)
         path.parent.mkdir(parents = True, exist_ok = True)
@@ -298,8 +298,7 @@ class DiffusionPriorTrainer(nn.Module):
             optimizer = self.optimizer.state_dict(),
             model = self.diffusion_prior.state_dict(),
             version = get_pkg_version(),
-            step = self.step.item(),
-            **kwargs
+            step = self.step.item()
         )
 
         if self.use_ema:
@@ -320,7 +319,7 @@ class DiffusionPriorTrainer(nn.Module):
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
 
         if only_model:
-            return loaded_obj
+            return
 
         self.scaler.load_state_dict(loaded_obj['scaler'])
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
@@ -329,8 +328,6 @@ class DiffusionPriorTrainer(nn.Module):
         assert 'ema' in loaded_obj
         self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
 
-        return loaded_obj
-
     def update(self):
         if exists(self.max_grad_norm):
             self.scaler.unscale_(self.optimizer)
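The two `load` hunks above change what the method hands back: `return loaded_obj` gives the caller the full checkpoint dict (so the saved `step` and `version` stay inspectable even on an `only_model` load), while the bare `return` discards it. A hedged usage sketch; the trainer instance and checkpoint path are hypothetical:

    # behavior on the '-' side, where load returns the checkpoint dict
    loaded_obj = trainer.load('./prior.pt', only_model = True)
    print(loaded_obj['version'], loaded_obj['step'])

    # on the '+' side the same call returns None once only_model is set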
@@ -452,7 +449,7 @@ class DecoderTrainer(nn.Module):
 
         self.register_buffer('step', torch.tensor([0.]))
 
-    def save(self, path, overwrite = True, **kwargs):
+    def save(self, path, overwrite = True):
         path = Path(path)
         assert not (path.exists() and not overwrite)
         path.parent.mkdir(parents = True, exist_ok = True)
@@ -460,8 +457,7 @@ class DecoderTrainer(nn.Module):
         save_obj = dict(
             model = self.decoder.state_dict(),
             version = get_pkg_version(),
-            step = self.step.item(),
-            **kwargs
+            step = self.step.item()
         )
 
         for ind in range(0, self.num_unets):
@@ -489,7 +485,7 @@ class DecoderTrainer(nn.Module):
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
 
         if only_model:
-            return loaded_obj
+            return
 
         for ind in range(0, self.num_unets):
             scaler_key = f'scaler{ind}'
@@ -504,8 +500,6 @@ class DecoderTrainer(nn.Module):
         assert 'ema' in loaded_obj
         self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
 
-        return loaded_obj
-
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
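The `save` changes in both trainers follow the same pattern: with `**kwargs` in the signature, extra keyword arguments are spread into `save_obj` and persisted alongside the model state; without it, checkpoints are limited to the fixed keys. A hedged sketch of what the `**kwargs` variant permits (the `epoch` key is hypothetical):

    # only valid with the '**kwargs' signature: 'epoch' rides along in save_obj
    trainer.save('./decoder.pt', epoch = 5)

    loaded_obj = trainer.load('./decoder.pt')
    print(loaded_obj['epoch'])   # extra metadata restored with the checkpoint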

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.10',
+  version = '0.4.9',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',