mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
incorrect naming
This commit is contained in:
@@ -80,13 +80,13 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
|
||||
|
||||
batch_size = len(first_tensor)
|
||||
split_size = default(split_size, batch_size)
|
||||
chunk_size = ceil(batch_size / split_size)
|
||||
num_chunks = ceil(batch_size / split_size)
|
||||
|
||||
dict_len = len(kwargs)
|
||||
dict_keys = kwargs.keys()
|
||||
split_kwargs_index = len_all_args - dict_len
|
||||
|
||||
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
|
||||
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
|
||||
chunk_sizes = tuple(map(len, split_all_args[0]))
|
||||
|
||||
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
|
||||
|
||||
Reference in New Issue
Block a user