incorrect naming

This commit is contained in:
Phil Wang
2022-05-15 11:23:52 -07:00
parent f7cd4a0992
commit e66c7b0249

View File

@@ -80,13 +80,13 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
batch_size = len(first_tensor)
split_size = default(split_size, batch_size)
chunk_size = ceil(batch_size / split_size)
num_chunks = ceil(batch_size / split_size)
dict_len = len(kwargs)
dict_keys = kwargs.keys()
split_kwargs_index = len_all_args - dict_len
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):