diff --git a/train_decoder.py b/train_decoder.py
index f06dadc..22ff816 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -323,7 +323,7 @@ def train(
         last_snapshot = sample
 
         if next_task == 'train':
-            for i, (img, emb, txt) in enumerate(trainer.train_loader):
+            for i, (img, emb, txt) in enumerate(dataloaders["train"]):
                 # We want to count the total number of samples across all processes
                 sample_length_tensor[0] = len(img)
                 all_samples = accelerator.gather(sample_length_tensor)  # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -358,6 +358,7 @@ def train(
                 else:
                     # Then we need to pass the text instead
                     tokenized_texts = tokenize(txt, truncate=True)
+                    assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
                     forward_params['text'] = tokenized_texts
                 loss = trainer.forward(img, **forward_params, unet_number=unet)
                 trainer.update(unet_number=unet)
@@ -416,7 +417,7 @@ def train(
         timer = Timer()
         accelerator.wait_for_everyone()
         i = 0
-        for i, (img, emb, txt) in enumerate(trainer.val_loader):  # Use the accelerate prepared loader
+        for i, (img, emb, txt) in enumerate(dataloaders['val']):  # Use the accelerate prepared loader
             val_sample_length_tensor[0] = len(img)
             all_samples = accelerator.gather(val_sample_length_tensor)
             total_samples = all_samples.sum().item()