Compare commits

...

2 Commits

Author     SHA1        Message                                             Date
Phil Wang  d0c11b30b0  handle open clip adapter image size being a tuple   2022-09-19 10:27:14 -07:00
zion       86e2d5ba84  Minor Decoder Train Script Fixes (#242)             2022-09-15 17:21:48 -07:00
                       * ensure tokenized text is on proper device
                       * fix lpips image distribution
3 changed files with 9 additions and 6 deletions

View File

@@ -406,7 +406,10 @@ class OpenClipAdapter(BaseClipAdapter):
     @property
     def image_size(self):
-        return self.clip.visual.image_size
+        image_size = self.clip.visual.image_size
+        if isinstance(image_size, tuple):
+            return max(image_size)
+        return image_size
 
     @property
     def image_channels(self):
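
For context, some open_clip vision towers report image_size as an (H, W) tuple rather than a single int, which is what this property now coerces. A minimal standalone sketch of the new logic (coerce_image_size is a hypothetical name used for illustration only):

    # Sketch of the coercion now performed by OpenClipAdapter.image_size
    def coerce_image_size(image_size):
        # open_clip may expose image_size as an (H, W) tuple; take the larger side
        if isinstance(image_size, tuple):
            return max(image_size)
        return image_size

    assert coerce_image_size((224, 224)) == 224  # tuple collapses to a single int
    assert coerce_image_size(256) == 256         # plain ints pass through unchanged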

View File

@@ -1 +1 @@
-__version__ = '1.10.5'
+__version__ = '1.10.6'

View File

@@ -156,7 +156,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None
     if text_embeddings[0] is None:
         # Generate text embeddings from text
         assert clip is not None, "clip is None, but text_embeddings is None"
-        tokenized_texts = tokenize(txts, truncate=True)
+        tokenized_texts = tokenize(txts, truncate=True).to(device=device)
         text_embed, text_encodings = clip.embed_text(tokenized_texts)
         sample_params["text_encodings"] = text_encodings
     else:
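
The .to(device=device) call matters because tokenize returns CPU tensors while clip typically lives on the GPU, so clip.embed_text would otherwise raise a device-mismatch error. A minimal sketch of the failure mode, with a toy embedding layer standing in for CLIP's text encoder (the same fix is applied again in train() below):

    import torch

    encoder = torch.nn.Embedding(49408, 512)       # toy stand-in for clip.embed_text
    if torch.cuda.is_available():
        encoder = encoder.cuda()
        tokens = torch.randint(0, 49408, (4, 77))  # tokenize(...) yields CPU tensors
        # encoder(tokens)                          # RuntimeError: tensors on different devices
        out = encoder(tokens.to(device='cuda'))    # moving tokens to the model's device succeeds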
@@ -229,8 +229,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None
         metrics["KID_std"] = kid_std.item()
     if exists(LPIPS):
         # Convert from [0, 1] to [-1, 1]
-        renorm_real_images = real_images.mul(2).sub(1)
-        renorm_generated_images = generated_images.mul(2).sub(1)
+        renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
+        renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
         lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
         lpips.to(device=device)
         lpips.update(renorm_real_images, renorm_generated_images)
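
The clamp is needed because torchmetrics' LPIPS expects inputs in [-1, 1] when normalize=False (its default), and floating-point rounding in mul(2).sub(1) can push values just past those bounds, tripping the range check. A minimal sketch, assuming the torchmetrics implementation (which needs the lpips backend installed) and random [0, 1] images:

    import torch
    from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

    lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
    real = torch.rand(4, 3, 64, 64)  # images in [0, 1]
    fake = torch.rand(4, 3, 64, 64)
    # Rescale to [-1, 1]; the clamp guards against float drift past the bounds
    lpips.update(real.mul(2).sub(1).clamp(-1, 1), fake.mul(2).sub(1).clamp(-1, 1))
    print(lpips.compute())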
@@ -480,7 +480,7 @@ def train(
     else:
         # Then we need to pass the text instead
         assert clip is not None
-        tokenized_texts = tokenize(txt, truncate=True)
+        tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
         assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
         text_embed, text_encodings = clip.embed_text(tokenized_texts)
         forward_params['text_encodings'] = text_encodings