mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 01:34:19 +01:00
Minor Decoder Train Script Fixes (#242)
* ensure tokenized text is on proper device * fix lpips mage distribution
This commit is contained in:
@@ -156,7 +156,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
|
||||
if text_embeddings[0] is None:
|
||||
# Generate text embeddings from text
|
||||
assert clip is not None, "clip is None, but text_embeddings is None"
|
||||
tokenized_texts = tokenize(txts, truncate=True)
|
||||
tokenized_texts = tokenize(txts, truncate=True).to(device=device)
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
sample_params["text_encodings"] = text_encodings
|
||||
else:
|
||||
@@ -229,8 +229,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=Non
|
||||
metrics["KID_std"] = kid_std.item()
|
||||
if exists(LPIPS):
|
||||
# Convert from [0, 1] to [-1, 1]
|
||||
renorm_real_images = real_images.mul(2).sub(1)
|
||||
renorm_generated_images = generated_images.mul(2).sub(1)
|
||||
renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
|
||||
renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
|
||||
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
|
||||
lpips.to(device=device)
|
||||
lpips.update(renorm_real_images, renorm_generated_images)
|
||||
@@ -480,7 +480,7 @@ def train(
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
assert clip is not None
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
|
||||
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
forward_params['text_encodings'] = text_encodings
|
||||
|
||||
Reference in New Issue
Block a user