refactor: Reduce breakage in vector memory module and split MemoryItem class

- Refactored the `MemoryItem` class in the `autogpt.memory.vector.memory_item` module to improve code organization and readability.
- Split construction logic out of `MemoryItem` into a separate `MemoryItemFactory` class that is initialized with a `ChatModelProvider` and an `EmbeddingModelProvider` (see the usage sketch below).
- Made the `get_embedding` function in the `autogpt.memory.vector.utils` module async and changed it to accept an `EmbeddingModelProvider` for creating embeddings.
- Updated the calls to `get_embedding` in the new `MemoryItemFactory` class to pass the `embedding_provider` parameter.
- Updated the imports in the affected modules.
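
A minimal sketch of what a call site looks like after this change, based on the new class definition in this diff; the `llm_provider`, `embedding_provider`, and `memory` objects are placeholders for whatever the surrounding application provides and are not part of this commit:

```python
from autogpt.config import Config
from autogpt.memory.vector import MemoryItemFactory

async def ingest_text_file(content: str, path: str, config: Config,
                           llm_provider, embedding_provider, memory):
    # Construction logic now lives on a factory that carries the chat and
    # embedding model providers, instead of @staticmethods on MemoryItem.
    factory = MemoryItemFactory(
        llm_provider=llm_provider,
        embedding_provider=embedding_provider,
    )
    # from_text() is now a coroutine, so the helpers built on top of it
    # (from_text_file, from_webpage, from_ai_action, ...) must be awaited.
    file_memory = await factory.from_text_file(content, path, config)
    memory.add(file_memory)
    return file_memory
```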
Reinier van der Leer
2023-12-02 15:32:34 +01:00
parent 6d439f4f63
commit ef35702c4b
5 changed files with 188 additions and 153 deletions

View File

@@ -14,7 +14,7 @@ from autogpt.agents.agent import Agent
from autogpt.agents.utils.exceptions import DuplicateOperationError
from autogpt.command_decorator import command
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.memory.vector import MemoryItem, VectorMemory
from autogpt.memory.vector import MemoryItemFactory, VectorMemory
from .decorators import sanitize_path_arg
from .file_operations_utils import read_textual_file
@@ -184,7 +184,7 @@ def ingest_file(
content = read_file(filename)
# TODO: differentiate between different types of files
file_memory = MemoryItem.from_text_file(content, filename)
file_memory = MemoryItemFactory.from_text_file(content, filename)
logger.debug(f"Created memory: {file_memory.dump(True)}")
memory.add(file_memory)

View File

@@ -1,6 +1,6 @@
from autogpt.config import Config
from .memory_item import MemoryItem, MemoryItemRelevance
from .memory_item import MemoryItem, MemoryItemFactory, MemoryItemRelevance
from .providers.base import VectorMemoryProvider as VectorMemory
from .providers.json_file import JSONFileMemory
from .providers.no_memory import NoMemory
@@ -144,6 +144,7 @@ def get_supported_memory_backends():
__all__ = [
"get_memory",
"MemoryItem",
"MemoryItemFactory",
"MemoryItemRelevance",
"JSONFileMemory",
"NoMemory",

View File

@@ -9,7 +9,11 @@ import numpy as np
from pydantic import BaseModel
from autogpt.config import Config
from autogpt.core.resource.model_providers import ChatMessage
from autogpt.core.resource.model_providers import (
ChatMessage,
ChatModelProvider,
EmbeddingModelProvider,
)
from autogpt.processing.text import chunk_content, split_text, summarize_text
from .utils import Embedding, get_embedding
@@ -19,7 +23,6 @@ logger = logging.getLogger(__name__)
MemoryDocType = Literal["webpage", "text_file", "code_file", "agent_history"]
# FIXME: implement validators instead of allowing arbitrary types
class MemoryItem(BaseModel, arbitrary_types_allowed=True):
"""Memory object containing raw content as well as embeddings"""
@@ -34,141 +37,11 @@ class MemoryItem(BaseModel, arbitrary_types_allowed=True):
def relevance_for(self, query: str, e_query: Embedding | None = None):
return MemoryItemRelevance.of(self, query, e_query)
@staticmethod
def from_text(
text: str,
source_type: MemoryDocType,
config: Config,
metadata: dict = {},
how_to_summarize: str | None = None,
question_for_summary: str | None = None,
):
logger.debug(f"Memorizing text:\n{'-'*32}\n{text}\n{'-'*32}\n")
# Fix encoding, e.g. removing unicode surrogates (see issue #778)
text = ftfy.fix_text(text)
# FIXME: needs ModelProvider
chunks = [
chunk
for chunk, _ in (
split_text(text, config.embedding_model, config)
if source_type != "code_file"
else chunk_content(text, config.embedding_model)
)
]
logger.debug("Chunks: " + str(chunks))
chunk_summaries = [
summary
for summary, _ in [
summarize_text(
text_chunk,
config,
instruction=how_to_summarize,
question=question_for_summary,
)
for text_chunk in chunks
]
]
logger.debug("Chunk summaries: " + str(chunk_summaries))
e_chunks = get_embedding(chunks, config)
summary = (
chunk_summaries[0]
if len(chunks) == 1
else summarize_text(
"\n\n".join(chunk_summaries),
config,
instruction=how_to_summarize,
question=question_for_summary,
)[0]
)
logger.debug("Total summary: " + summary)
# TODO: investigate search performance of weighted average vs summary
# e_average = np.average(e_chunks, axis=0, weights=[len(c) for c in chunks])
e_summary = get_embedding(summary, config)
metadata["source_type"] = source_type
return MemoryItem(
raw_content=text,
summary=summary,
chunks=chunks,
chunk_summaries=chunk_summaries,
e_summary=e_summary,
e_chunks=e_chunks,
metadata=metadata,
)
@staticmethod
def from_text_file(content: str, path: str, config: Config):
return MemoryItem.from_text(content, "text_file", config, {"location": path})
@staticmethod
def from_code_file(content: str, path: str):
# TODO: implement tailored code memories
return MemoryItem.from_text(content, "code_file", {"location": path})
@staticmethod
def from_ai_action(ai_message: ChatMessage, result_message: ChatMessage):
# The result_message contains either user feedback
# or the result of the command specified in ai_message
if ai_message.role != "assistant":
raise ValueError(f"Invalid role on 'ai_message': {ai_message.role}")
result = (
result_message.content
if result_message.content.startswith("Command")
else "None"
)
user_input = (
result_message.content
if result_message.content.startswith("Human feedback")
else "None"
)
memory_content = (
f"Assistant Reply: {ai_message.content}"
"\n\n"
f"Result: {result}"
"\n\n"
f"Human Feedback: {user_input}"
)
return MemoryItem.from_text(
text=memory_content,
source_type="agent_history",
how_to_summarize=(
"if possible, also make clear the link between the command in the"
" assistant's response and the command result. "
"Do not mention the human feedback if there is none.",
),
)
@staticmethod
def from_webpage(
content: str, url: str, config: Config, question: str | None = None
):
return MemoryItem.from_text(
text=content,
source_type="webpage",
config=config,
metadata={"location": url},
question_for_summary=question,
)
def dump(self, calculate_length=False) -> str:
if calculate_length:
token_length = self.llm_provider.count_tokens(
self.raw_content, Config().embedding_model
)
n_chunks = len(self.e_chunks)
return f"""
=============== MemoryItem ===============
Size: {f'{token_length} tokens in ' if calculate_length else ''}{n_chunks} chunks
Size: {n_chunks} chunks
Metadata: {json.dumps(self.metadata, indent=2)}
---------------- SUMMARY -----------------
{self.summary}
@@ -203,6 +76,152 @@ Metadata: {json.dumps(self.metadata, indent=2)}
)
class MemoryItemFactory:
def __init__(
self,
llm_provider: ChatModelProvider,
embedding_provider: EmbeddingModelProvider,
):
self.llm_provider = llm_provider
self.embedding_provider = embedding_provider
async def from_text(
self,
text: str,
source_type: MemoryDocType,
config: Config,
metadata: dict = {},
how_to_summarize: str | None = None,
question_for_summary: str | None = None,
):
logger.debug(f"Memorizing text:\n{'-'*32}\n{text}\n{'-'*32}\n")
# Fix encoding, e.g. removing unicode surrogates (see issue #778)
text = ftfy.fix_text(text)
# FIXME: needs ModelProvider
chunks = [
chunk
for chunk, _ in (
split_text(
text=text,
config=config,
max_chunk_length=1000, # arbitrary, but shorter ~= better
tokenizer=self.llm_provider.get_tokenizer(config.fast_llm),
)
if source_type != "code_file"
# TODO: chunk code based on structure/outline
else chunk_content(
content=text,
max_chunk_length=1000,
tokenizer=self.llm_provider.get_tokenizer(config.fast_llm),
)
)
]
logger.debug("Chunks: " + str(chunks))
chunk_summaries = [
summary
for summary, _ in [
await summarize_text(
text=text_chunk,
instruction=how_to_summarize,
question=question_for_summary,
llm_provider=self.llm_provider,
config=config,
)
for text_chunk in chunks
]
]
logger.debug("Chunk summaries: " + str(chunk_summaries))
e_chunks = get_embedding(chunks, config, self.embedding_provider)
summary = (
chunk_summaries[0]
if len(chunks) == 1
else (
await summarize_text(
text="\n\n".join(chunk_summaries),
instruction=how_to_summarize,
question=question_for_summary,
llm_provider=self.llm_provider,
config=config,
)
)[0]
)
logger.debug("Total summary: " + summary)
# TODO: investigate search performance of weighted average vs summary
# e_average = np.average(e_chunks, axis=0, weights=[len(c) for c in chunks])
e_summary = get_embedding(summary, config, self.embedding_provider)
metadata["source_type"] = source_type
return MemoryItem(
raw_content=text,
summary=summary,
chunks=chunks,
chunk_summaries=chunk_summaries,
e_summary=e_summary,
e_chunks=e_chunks,
metadata=metadata,
)
def from_text_file(self, content: str, path: str, config: Config):
return self.from_text(content, "text_file", config, {"location": path})
def from_code_file(self, content: str, path: str):
# TODO: implement tailored code memories
return self.from_text(content, "code_file", {"location": path})
def from_ai_action(self, ai_message: ChatMessage, result_message: ChatMessage):
# The result_message contains either user feedback
# or the result of the command specified in ai_message
if ai_message.role != "assistant":
raise ValueError(f"Invalid role on 'ai_message': {ai_message.role}")
result = (
result_message.content
if result_message.content.startswith("Command")
else "None"
)
user_input = (
result_message.content
if result_message.content.startswith("Human feedback")
else "None"
)
memory_content = (
f"Assistant Reply: {ai_message.content}"
"\n\n"
f"Result: {result}"
"\n\n"
f"Human Feedback: {user_input}"
)
return self.from_text(
text=memory_content,
source_type="agent_history",
how_to_summarize=(
"if possible, also make clear the link between the command in the"
" assistant's response and the command result. "
"Do not mention the human feedback if there is none.",
),
)
def from_webpage(
self, content: str, url: str, config: Config, question: str | None = None
):
return self.from_text(
text=content,
source_type="webpage",
config=config,
metadata={"location": url},
question_for_summary=question,
)
class MemoryItemRelevance(BaseModel):
"""
Class that encapsulates memory relevance search functionality and data.

View File

@@ -5,6 +5,7 @@ from typing import Any, Sequence, overload
import numpy as np
from autogpt.config import Config
from autogpt.core.resource.model_providers import EmbeddingModelProvider
logger = logging.getLogger(__name__)
@@ -16,23 +17,32 @@ TText = Sequence[int]
@overload
def get_embedding(input: str | TText, config: Config) -> Embedding:
async def get_embedding(
input: str | TText, config: Config, embedding_provider: EmbeddingModelProvider
) -> Embedding:
...
@overload
def get_embedding(input: list[str] | list[TText], config: Config) -> list[Embedding]:
async def get_embedding(
input: list[str] | list[TText],
config: Config,
embedding_provider: EmbeddingModelProvider,
) -> list[Embedding]:
...
def get_embedding(
input: str | TText | list[str] | list[TText], config: Config
async def get_embedding(
input: str | TText | list[str] | list[TText],
config: Config,
embedding_provider: EmbeddingModelProvider,
) -> Embedding | list[Embedding]:
"""Get an embedding from the ada model.
Args:
input: Input text to get embeddings for, encoded as a string or array of tokens.
Multiple inputs may be given as a list of strings or token arrays.
embedding_provider: The provider to create embeddings.
Returns:
List[float]: The embedding.
@@ -52,25 +62,30 @@ def get_embedding(
return [_get_embedding_with_plugin(i, config) for i in input]
model = config.embedding_model
kwargs = {"model": model}
kwargs.update(config.get_openai_credentials(model))
logger.debug(
f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}"
f" with model '{model}'"
+ (f" via Azure deployment '{kwargs['engine']}'" if config.use_azure else "")
)
embeddings = embedding_provider.create_embedding(
input,
**kwargs,
).data
if not multiple:
return embeddings[0]["embedding"]
embeddings = sorted(embeddings, key=lambda x: x["index"])
return [d["embedding"] for d in embeddings]
return (
await embedding_provider.create_embedding(
text=input,
model_name=model,
embedding_parser=lambda e: e,
)
).embedding
else:
embeddings = []
for text in input:
result = await embedding_provider.create_embedding(
text=text,
model_name=model,
embedding_parser=lambda e: e,
)
embeddings.append(result.embedding)
return embeddings
def _get_embedding_with_plugin(text: str, config: Config) -> Embedding:
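
For reference, a minimal sketch of calling the new async `get_embedding`, assuming an already-configured `EmbeddingModelProvider` instance:

```python
from autogpt.config import Config
from autogpt.core.resource.model_providers import EmbeddingModelProvider
from autogpt.memory.vector.utils import get_embedding

async def embed_query_and_docs(
    embedding_provider: EmbeddingModelProvider, config: Config
):
    # A single string (or token sequence) yields a single Embedding; a list
    # of inputs yields a list of Embeddings, matching the overloads above.
    e_query = await get_embedding("example query", config, embedding_provider)
    e_docs = await get_embedding(
        ["first document", "second document"], config, embedding_provider
    )
    return e_query, e_docs
```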

View File

@@ -26,7 +26,7 @@ def mock_MemoryItem_from_text(
mocker: MockerFixture, mock_embedding: Embedding, config: Config
):
mocker.patch.object(
file_ops.MemoryItem,
file_ops.MemoryItemFactory,
"from_text",
new=lambda content, source_type, config, metadata: MemoryItem(
raw_content=content,