Reverted to using basic dataloaders (#205)

Accelerate removes the ability to collate strings, likely because it
cannot gather strings across processes.
This commit is contained in:
Aidan Dempster
2022-07-12 21:22:27 -04:00
committed by GitHub
parent 349aaca56f
commit 544cdd0b29

View File

@@ -323,7 +323,7 @@ def train(
             last_snapshot = sample
             if next_task == 'train':
-                for i, (img, emb, txt) in enumerate(trainer.train_loader):
+                for i, (img, emb, txt) in enumerate(dataloaders["train"]):
                     # We want to count the total number of samples across all processes
                     sample_length_tensor[0] = len(img)
                     all_samples = accelerator.gather(sample_length_tensor)  # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -358,6 +358,7 @@ def train(
                 else:
                     # Then we need to pass the text instead
                     tokenized_texts = tokenize(txt, truncate=True)
+                    assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
                     forward_params['text'] = tokenized_texts
                 loss = trainer.forward(img, **forward_params, unet_number=unet)
                 trainer.update(unet_number=unet)
@@ -416,7 +417,7 @@ def train(
             timer = Timer()
             accelerator.wait_for_everyone()
             i = 0
-            for i, (img, emb, txt) in enumerate(trainer.val_loader):  # Use the accelerate prepared loader
+            for i, (img, emb, txt) in enumerate(dataloaders['val']):  # Use the accelerate prepared loader
                 val_sample_length_tensor[0] = len(img)
                 all_samples = accelerator.gather(val_sample_length_tensor)
                 total_samples = all_samples.sum().item()