From 544cdd0b296accad1c1b501582394507939f1174 Mon Sep 17 00:00:00 2001 From: Aidan Dempster Date: Tue, 12 Jul 2022 21:22:27 -0400 Subject: [PATCH] Reverted to using basic dataloaders (#205) Accelerate removes the ability to collate strings. Likely since it cannot gather strings. --- train_decoder.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/train_decoder.py b/train_decoder.py index f06dadc..22ff816 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -323,7 +323,7 @@ def train( last_snapshot = sample if next_task == 'train': - for i, (img, emb, txt) in enumerate(trainer.train_loader): + for i, (img, emb, txt) in enumerate(dataloaders["train"]): # We want to count the total number of samples across all processes sample_length_tensor[0] = len(img) all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this. @@ -358,6 +358,7 @@ def train( else: # Then we need to pass the text instead tokenized_texts = tokenize(txt, truncate=True) + assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})" forward_params['text'] = tokenized_texts loss = trainer.forward(img, **forward_params, unet_number=unet) trainer.update(unet_number=unet) @@ -416,7 +417,7 @@ def train( timer = Timer() accelerator.wait_for_everyone() i = 0 - for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader + for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader val_sample_length_tensor[0] = len(img) all_samples = accelerator.gather(val_sample_length_tensor) total_samples = all_samples.sum().item()