Disable unproven paragraph flattening in split_text()

This commit is contained in:
Reinier van der Leer
2023-08-29 02:06:47 +02:00
parent 6fac2386c7
commit d2cc22c698

View File

@@ -1,7 +1,7 @@
"""Text processing functions"""
import logging
from math import ceil
from typing import Iterator, Optional, Sequence
from typing import Iterator, Optional, Sequence, TypeVar
import spacy
import tiktoken
@@ -13,14 +13,18 @@ from autogpt.llm.utils import count_string_tokens, create_chat_completion
logger = logging.getLogger(__name__)
T = TypeVar("T")
def batch(iterable: Sequence, max_batch_length: int, overlap: int = 0):
def batch(
sequence: Sequence[T], max_batch_length: int, overlap: int = 0
) -> Iterator[Sequence[T]]:
"""Batch data from iterable into slices of length N. The last batch may be shorter."""
# batched('ABCDEFG', 3) --> ABC DEF G
if max_batch_length < 1:
raise ValueError("n must be at least one")
for i in range(0, len(iterable), max_batch_length - overlap):
yield iterable[i : i + max_batch_length]
for i in range(0, len(sequence), max_batch_length - overlap):
yield sequence[i : i + max_batch_length]
def _max_chunk_length(model: str, max: Optional[int] = None) -> int:
@@ -42,7 +46,7 @@ def chunk_content(
content: str,
for_model: str,
max_chunk_length: Optional[int] = None,
with_overlap=True,
with_overlap: bool = True,
) -> Iterator[tuple[str, int]]:
"""Split content into chunks of approximately equal token length."""
@@ -155,7 +159,7 @@ def split_text(
text: str,
for_model: str,
config: Config,
with_overlap=True,
with_overlap: bool = True,
max_chunk_length: Optional[int] = None,
) -> Iterator[tuple[str, int]]:
"""Split text into chunks of sentences, with each chunk not exceeding the maximum length
@@ -176,8 +180,6 @@ def split_text(
max_length = _max_chunk_length(for_model, max_chunk_length)
# flatten paragraphs to improve performance
text = text.replace("\n", " ")
text_length = count_string_tokens(text, for_model)
if text_length < max_length: