diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 33d37c0..c94f8d7 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1712,7 +1712,7 @@ class Decoder(BaseGaussianDiffusion):
         self.unconditional = unconditional
 
         assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
-        assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
+        assert self.unconditional or (exists(clip) or exists(image_size) or exists(image_sizes)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
 
         self.clip = None
         if exists(clip):
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 10aa288..acdde4b 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -288,7 +288,7 @@ class DiffusionPriorTrainer(nn.Module):
 
         self.register_buffer('step', torch.tensor([0]))
 
-    def save(self, path, overwrite = True):
+    def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
         assert not (path.exists() and not overwrite)
         path.parent.mkdir(parents = True, exist_ok = True)
@@ -298,7 +298,8 @@ class DiffusionPriorTrainer(nn.Module):
             optimizer = self.optimizer.state_dict(),
             model = self.diffusion_prior.state_dict(),
             version = get_pkg_version(),
-            step = self.step.item()
+            step = self.step.item(),
+            **kwargs
         )
 
         if self.use_ema:
@@ -319,7 +320,7 @@ class DiffusionPriorTrainer(nn.Module):
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
 
         if only_model:
-            return
+            return loaded_obj
 
         self.scaler.load_state_dict(loaded_obj['scaler'])
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
@@ -328,6 +329,8 @@ class DiffusionPriorTrainer(nn.Module):
             assert 'ema' in loaded_obj
             self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
 
+        return loaded_obj
+
     def update(self):
         if exists(self.max_grad_norm):
             self.scaler.unscale_(self.optimizer)
@@ -449,7 +452,7 @@ class DecoderTrainer(nn.Module):
 
         self.register_buffer('step', torch.tensor([0.]))
 
-    def save(self, path, overwrite = True):
+    def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
         assert not (path.exists() and not overwrite)
         path.parent.mkdir(parents = True, exist_ok = True)
@@ -457,7 +460,8 @@ class DecoderTrainer(nn.Module):
         save_obj = dict(
             model = self.decoder.state_dict(),
             version = get_pkg_version(),
-            step = self.step.item()
+            step = self.step.item(),
+            **kwargs
         )
 
         for ind in range(0, self.num_unets):
@@ -485,7 +489,7 @@ class DecoderTrainer(nn.Module):
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
 
         if only_model:
-            return
+            return loaded_obj
 
         for ind in range(0, self.num_unets):
             scaler_key = f'scaler{ind}'
@@ -500,6 +504,8 @@ class DecoderTrainer(nn.Module):
             assert 'ema' in loaded_obj
             self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
 
+        return loaded_obj
+
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
diff --git a/setup.py b/setup.py
index 60ca0e9..35ded9c 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.9',
+  version = '0.4.10',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
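
For illustration, a minimal sketch of the round-trip the `trainer.py` changes enable: keyword arguments passed to `save` ride along inside the checkpoint dict, and `load` now returns that dict so the caller can recover them. This is a hedged sketch, not code from the patch: model construction follows the repository README, exact constructor arguments may differ by version, and `epoch` / `best_loss` are hypothetical metadata keys.

```python
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.trainer import DiffusionPriorTrainer

# build a small prior purely so a trainer can be constructed
prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 2,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 512,
    timesteps = 100,
    cond_drop_prob = 0.2,
    condition_on_text_encodings = False  # no CLIP supplied in this sketch
)

trainer = DiffusionPriorTrainer(diffusion_prior, lr = 3e-4)

# save() now forwards extra keyword arguments into the checkpoint dict,
# alongside the existing 'model' / 'optimizer' / 'version' / 'step' keys
trainer.save('./prior.pt', epoch = 5, best_loss = 0.123)

# load() now returns the full checkpoint dict instead of None, so the
# metadata (and the saved package version) can be read back out
loaded_obj = trainer.load('./prior.pt')
print(loaded_obj['epoch'], loaded_obj['best_loss'], loaded_obj['version'])
```

Returning `loaded_obj` rather than `None` also means a resume script can inspect the saved `version` or its own metadata before deciding how to proceed, without re-reading the file with `torch.load`.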