diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py index 4d08389..1c073fd 100644 --- a/dalle2_pytorch/vqgan_vae.py +++ b/dalle2_pytorch/vqgan_vae.py @@ -68,8 +68,8 @@ def group_dict_by_key(cond, d): return_val[ind][key] = d[key] return (*return_val,) -def string_begins_with(prefix, str): - return str.startswith(prefix) +def string_begins_with(prefix, string_input): + return string_input.startswith(prefix) def group_by_key_prefix(prefix, d): return group_dict_by_key(partial(string_begins_with, prefix), d)