mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc50c6b34e |
@@ -890,8 +890,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
)
|
||||
|
||||
if exists(clip):
|
||||
assert image_channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
@@ -1712,19 +1710,12 @@ class Decoder(BaseGaussianDiffusion):
|
||||
)
|
||||
|
||||
self.unconditional = unconditional
|
||||
|
||||
# text conditioning
|
||||
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# clip
|
||||
assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
assert not unconditional, 'clip must not be given if doing unconditional image training'
|
||||
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
|
||||
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
@@ -1734,20 +1725,13 @@ class Decoder(BaseGaussianDiffusion):
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
|
||||
self.clip = clip
|
||||
|
||||
# determine image size, with image_size and image_sizes taking precedence
|
||||
|
||||
if exists(image_size) or exists(image_sizes):
|
||||
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
|
||||
image_size = default(image_size, lambda: image_sizes[-1])
|
||||
elif exists(clip):
|
||||
image_size = clip.image_size
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
else:
|
||||
raise Error('either image_size, image_sizes, or clip must be given to decoder')
|
||||
self.clip_image_size = default(image_size, lambda: image_sizes[-1])
|
||||
self.channels = channels
|
||||
|
||||
# channels
|
||||
|
||||
self.channels = channels
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
@@ -1789,7 +1773,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
# unet image sizes
|
||||
|
||||
image_sizes = default(image_sizes, (image_size,))
|
||||
image_sizes = default(image_sizes, (self.clip_image_size,))
|
||||
image_sizes = tuple(sorted(set(image_sizes)))
|
||||
|
||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||
@@ -1827,7 +1811,6 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.clip_x_start = clip_x_start
|
||||
|
||||
# normalize and unnormalize image functions
|
||||
|
||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class UnetConfig(BaseModel):
|
||||
extra = "allow"
|
||||
|
||||
class DecoderConfig(BaseModel):
|
||||
unets: ListOrTuple(UnetConfig)
|
||||
unets: Union[List[UnetConfig], Tuple[UnetConfig]]
|
||||
image_size: int = None
|
||||
image_sizes: ListOrTuple(int) = None
|
||||
channels: int = 3
|
||||
|
||||
@@ -288,7 +288,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
self.register_buffer('step', torch.tensor([0]))
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
def save(self, path, overwrite = True):
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
path.parent.mkdir(parents = True, exist_ok = True)
|
||||
@@ -298,8 +298,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
optimizer = self.optimizer.state_dict(),
|
||||
model = self.diffusion_prior.state_dict(),
|
||||
version = get_pkg_version(),
|
||||
step = self.step.item(),
|
||||
**kwargs
|
||||
step = self.step.item()
|
||||
)
|
||||
|
||||
if self.use_ema:
|
||||
@@ -320,7 +319,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
|
||||
if only_model:
|
||||
return loaded_obj
|
||||
return
|
||||
|
||||
self.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
@@ -329,8 +328,6 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
|
||||
return loaded_obj
|
||||
|
||||
def update(self):
|
||||
if exists(self.max_grad_norm):
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
@@ -452,7 +449,7 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
def save(self, path, overwrite = True):
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
path.parent.mkdir(parents = True, exist_ok = True)
|
||||
@@ -460,8 +457,7 @@ class DecoderTrainer(nn.Module):
|
||||
save_obj = dict(
|
||||
model = self.decoder.state_dict(),
|
||||
version = get_pkg_version(),
|
||||
step = self.step.item(),
|
||||
**kwargs
|
||||
step = self.step.item()
|
||||
)
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
@@ -489,7 +485,7 @@ class DecoderTrainer(nn.Module):
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
|
||||
if only_model:
|
||||
return loaded_obj
|
||||
return
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
scaler_key = f'scaler{ind}'
|
||||
@@ -504,8 +500,6 @@ class DecoderTrainer(nn.Module):
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
|
||||
return loaded_obj
|
||||
|
||||
@property
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
|
||||
Reference in New Issue
Block a user