Mirror of https://github.com/aljazceru/Auto-GPT.git (synced 2026-02-16 19:54:25 +01:00)
refactor: Reduce breakage in vector memory module and split MemoryItem class
- Refactored the `MemoryItem` class in the `autogpt.memory.vector.memory_item` module to improve code organization and readability.
- Split the `MemoryItem` class into two separate classes: `MemoryItem` and `MemoryItemFactory`.
- Modified the `get_embedding` function in the `autogpt.memory.vector.utils` module to accept an `EmbeddingModelProvider` for creating embeddings.
- Updated the usage of the `get_embedding` function to pass the `embedding_provider` parameter.
- Updated the imports in the affected modules.
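In practice, the split means call sites build a `MemoryItemFactory` around the providers it needs and create memory items through it, instead of calling static methods on `MemoryItem`. A minimal sketch of that usage, assuming the provider, config, and memory objects come from the agent's existing setup; the helper name `ingest_text_file` is illustrative and not part of this commit:

from autogpt.config import Config
from autogpt.core.resource.model_providers import (
    ChatModelProvider,
    EmbeddingModelProvider,
)
from autogpt.memory.vector import MemoryItemFactory, VectorMemory


async def ingest_text_file(
    content: str,
    path: str,
    config: Config,
    llm_provider: ChatModelProvider,
    embedding_provider: EmbeddingModelProvider,
    memory: VectorMemory,
) -> None:
    # Before: MemoryItem.from_text_file(content, path, config)
    # After: build a factory around the providers, then create items through it.
    factory = MemoryItemFactory(llm_provider, embedding_provider)
    # from_text_file delegates to the now-async from_text, so the result is awaited.
    item = await factory.from_text_file(content, path, config)
    memory.add(item)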
@@ -14,7 +14,7 @@ from autogpt.agents.agent import Agent
from autogpt.agents.utils.exceptions import DuplicateOperationError
from autogpt.command_decorator import command
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.memory.vector import MemoryItem, VectorMemory
from autogpt.memory.vector import MemoryItemFactory, VectorMemory

from .decorators import sanitize_path_arg
from .file_operations_utils import read_textual_file
@@ -184,7 +184,7 @@ def ingest_file(
    content = read_file(filename)

    # TODO: differentiate between different types of files
    file_memory = MemoryItem.from_text_file(content, filename)
    file_memory = MemoryItemFactory.from_text_file(content, filename)
    logger.debug(f"Created memory: {file_memory.dump(True)}")
    memory.add(file_memory)

@@ -1,6 +1,6 @@
from autogpt.config import Config

from .memory_item import MemoryItem, MemoryItemRelevance
from .memory_item import MemoryItem, MemoryItemFactory, MemoryItemRelevance
from .providers.base import VectorMemoryProvider as VectorMemory
from .providers.json_file import JSONFileMemory
from .providers.no_memory import NoMemory
@@ -144,6 +144,7 @@ def get_supported_memory_backends():
__all__ = [
    "get_memory",
    "MemoryItem",
    "MemoryItemFactory",
    "MemoryItemRelevance",
    "JSONFileMemory",
    "NoMemory",
@@ -9,7 +9,11 @@ import numpy as np
from pydantic import BaseModel

from autogpt.config import Config
from autogpt.core.resource.model_providers import ChatMessage
from autogpt.core.resource.model_providers import (
    ChatMessage,
    ChatModelProvider,
    EmbeddingModelProvider,
)
from autogpt.processing.text import chunk_content, split_text, summarize_text

from .utils import Embedding, get_embedding
@@ -19,7 +23,6 @@ logger = logging.getLogger(__name__)
MemoryDocType = Literal["webpage", "text_file", "code_file", "agent_history"]


# FIXME: implement validators instead of allowing arbitrary types
class MemoryItem(BaseModel, arbitrary_types_allowed=True):
    """Memory object containing raw content as well as embeddings"""

@@ -34,141 +37,11 @@ class MemoryItem(BaseModel, arbitrary_types_allowed=True):
    def relevance_for(self, query: str, e_query: Embedding | None = None):
        return MemoryItemRelevance.of(self, query, e_query)

    @staticmethod
    def from_text(
        text: str,
        source_type: MemoryDocType,
        config: Config,
        metadata: dict = {},
        how_to_summarize: str | None = None,
        question_for_summary: str | None = None,
    ):
        logger.debug(f"Memorizing text:\n{'-'*32}\n{text}\n{'-'*32}\n")

        # Fix encoding, e.g. removing unicode surrogates (see issue #778)
        text = ftfy.fix_text(text)

        # FIXME: needs ModelProvider
        chunks = [
            chunk
            for chunk, _ in (
                split_text(text, config.embedding_model, config)
                if source_type != "code_file"
                else chunk_content(text, config.embedding_model)
            )
        ]
        logger.debug("Chunks: " + str(chunks))

        chunk_summaries = [
            summary
            for summary, _ in [
                summarize_text(
                    text_chunk,
                    config,
                    instruction=how_to_summarize,
                    question=question_for_summary,
                )
                for text_chunk in chunks
            ]
        ]
        logger.debug("Chunk summaries: " + str(chunk_summaries))

        e_chunks = get_embedding(chunks, config)

        summary = (
            chunk_summaries[0]
            if len(chunks) == 1
            else summarize_text(
                "\n\n".join(chunk_summaries),
                config,
                instruction=how_to_summarize,
                question=question_for_summary,
            )[0]
        )
        logger.debug("Total summary: " + summary)

        # TODO: investigate search performance of weighted average vs summary
        # e_average = np.average(e_chunks, axis=0, weights=[len(c) for c in chunks])
        e_summary = get_embedding(summary, config)

        metadata["source_type"] = source_type

        return MemoryItem(
            raw_content=text,
            summary=summary,
            chunks=chunks,
            chunk_summaries=chunk_summaries,
            e_summary=e_summary,
            e_chunks=e_chunks,
            metadata=metadata,
        )

    @staticmethod
    def from_text_file(content: str, path: str, config: Config):
        return MemoryItem.from_text(content, "text_file", config, {"location": path})

    @staticmethod
    def from_code_file(content: str, path: str):
        # TODO: implement tailored code memories
        return MemoryItem.from_text(content, "code_file", {"location": path})

    @staticmethod
    def from_ai_action(ai_message: ChatMessage, result_message: ChatMessage):
        # The result_message contains either user feedback
        # or the result of the command specified in ai_message

        if ai_message.role != "assistant":
            raise ValueError(f"Invalid role on 'ai_message': {ai_message.role}")

        result = (
            result_message.content
            if result_message.content.startswith("Command")
            else "None"
        )
        user_input = (
            result_message.content
            if result_message.content.startswith("Human feedback")
            else "None"
        )
        memory_content = (
            f"Assistant Reply: {ai_message.content}"
            "\n\n"
            f"Result: {result}"
            "\n\n"
            f"Human Feedback: {user_input}"
        )

        return MemoryItem.from_text(
            text=memory_content,
            source_type="agent_history",
            how_to_summarize=(
                "if possible, also make clear the link between the command in the"
                " assistant's response and the command result. "
                "Do not mention the human feedback if there is none.",
            ),
        )

    @staticmethod
    def from_webpage(
        content: str, url: str, config: Config, question: str | None = None
    ):
        return MemoryItem.from_text(
            text=content,
            source_type="webpage",
            config=config,
            metadata={"location": url},
            question_for_summary=question,
        )

    def dump(self, calculate_length=False) -> str:
        if calculate_length:
            token_length = self.llm_provider.count_tokens(
                self.raw_content, Config().embedding_model
            )
        n_chunks = len(self.e_chunks)
        return f"""
=============== MemoryItem ===============
Size: {f'{token_length} tokens in ' if calculate_length else ''}{n_chunks} chunks
Size: {n_chunks} chunks
Metadata: {json.dumps(self.metadata, indent=2)}
---------------- SUMMARY -----------------
{self.summary}
@@ -203,6 +76,152 @@ Metadata: {json.dumps(self.metadata, indent=2)}
        )


class MemoryItemFactory:
    def __init__(
        self,
        llm_provider: ChatModelProvider,
        embedding_provider: EmbeddingModelProvider,
    ):
        self.llm_provider = llm_provider
        self.embedding_provider = embedding_provider

    async def from_text(
        self,
        text: str,
        source_type: MemoryDocType,
        config: Config,
        metadata: dict = {},
        how_to_summarize: str | None = None,
        question_for_summary: str | None = None,
    ):
        logger.debug(f"Memorizing text:\n{'-'*32}\n{text}\n{'-'*32}\n")

        # Fix encoding, e.g. removing unicode surrogates (see issue #778)
        text = ftfy.fix_text(text)

        # FIXME: needs ModelProvider
        chunks = [
            chunk
            for chunk, _ in (
                split_text(
                    text=text,
                    config=config,
                    max_chunk_length=1000,  # arbitrary, but shorter ~= better
                    tokenizer=self.llm_provider.get_tokenizer(config.fast_llm),
                )
                if source_type != "code_file"
                # TODO: chunk code based on structure/outline
                else chunk_content(
                    content=text,
                    max_chunk_length=1000,
                    tokenizer=self.llm_provider.get_tokenizer(config.fast_llm),
                )
            )
        ]
        logger.debug("Chunks: " + str(chunks))

        chunk_summaries = [
            summary
            for summary, _ in [
                await summarize_text(
                    text=text_chunk,
                    instruction=how_to_summarize,
                    question=question_for_summary,
                    llm_provider=self.llm_provider,
                    config=config,
                )
                for text_chunk in chunks
            ]
        ]
        logger.debug("Chunk summaries: " + str(chunk_summaries))

        e_chunks = get_embedding(chunks, config, self.embedding_provider)

        summary = (
            chunk_summaries[0]
            if len(chunks) == 1
            else (
                await summarize_text(
                    text="\n\n".join(chunk_summaries),
                    instruction=how_to_summarize,
                    question=question_for_summary,
                    llm_provider=self.llm_provider,
                    config=config,
                )
            )[0]
        )
        logger.debug("Total summary: " + summary)

        # TODO: investigate search performance of weighted average vs summary
        # e_average = np.average(e_chunks, axis=0, weights=[len(c) for c in chunks])
        e_summary = get_embedding(summary, config, self.embedding_provider)

        metadata["source_type"] = source_type

        return MemoryItem(
            raw_content=text,
            summary=summary,
            chunks=chunks,
            chunk_summaries=chunk_summaries,
            e_summary=e_summary,
            e_chunks=e_chunks,
            metadata=metadata,
        )

    def from_text_file(self, content: str, path: str, config: Config):
        return self.from_text(content, "text_file", config, {"location": path})

    def from_code_file(self, content: str, path: str):
        # TODO: implement tailored code memories
        return self.from_text(content, "code_file", {"location": path})

    def from_ai_action(self, ai_message: ChatMessage, result_message: ChatMessage):
        # The result_message contains either user feedback
        # or the result of the command specified in ai_message

        if ai_message.role != "assistant":
            raise ValueError(f"Invalid role on 'ai_message': {ai_message.role}")

        result = (
            result_message.content
            if result_message.content.startswith("Command")
            else "None"
        )
        user_input = (
            result_message.content
            if result_message.content.startswith("Human feedback")
            else "None"
        )
        memory_content = (
            f"Assistant Reply: {ai_message.content}"
            "\n\n"
            f"Result: {result}"
            "\n\n"
            f"Human Feedback: {user_input}"
        )

        return self.from_text(
            text=memory_content,
            source_type="agent_history",
            how_to_summarize=(
                "if possible, also make clear the link between the command in the"
                " assistant's response and the command result. "
                "Do not mention the human feedback if there is none.",
            ),
        )

    def from_webpage(
        self, content: str, url: str, config: Config, question: str | None = None
    ):
        return self.from_text(
            text=content,
            source_type="webpage",
            config=config,
            metadata={"location": url},
            question_for_summary=question,
        )


class MemoryItemRelevance(BaseModel):
    """
    Class that encapsulates memory relevance search functionality and data.

@@ -5,6 +5,7 @@ from typing import Any, Sequence, overload
import numpy as np

from autogpt.config import Config
from autogpt.core.resource.model_providers import EmbeddingModelProvider

logger = logging.getLogger(__name__)
@@ -16,23 +17,32 @@ TText = Sequence[int]


@overload
def get_embedding(input: str | TText, config: Config) -> Embedding:
async def get_embedding(
    input: str | TText, config: Config, embedding_provider: EmbeddingModelProvider
) -> Embedding:
    ...


@overload
def get_embedding(input: list[str] | list[TText], config: Config) -> list[Embedding]:
async def get_embedding(
    input: list[str] | list[TText],
    config: Config,
    embedding_provider: EmbeddingModelProvider,
) -> list[Embedding]:
    ...


def get_embedding(
    input: str | TText | list[str] | list[TText], config: Config
async def get_embedding(
    input: str | TText | list[str] | list[TText],
    config: Config,
    embedding_provider: EmbeddingModelProvider,
) -> Embedding | list[Embedding]:
    """Get an embedding from the ada model.

    Args:
        input: Input text to get embeddings for, encoded as a string or array of tokens.
            Multiple inputs may be given as a list of strings or token arrays.
        embedding_provider: The provider to create embeddings.

    Returns:
        List[float]: The embedding.
@@ -52,25 +62,30 @@ def get_embedding(
        return [_get_embedding_with_plugin(i, config) for i in input]

    model = config.embedding_model
    kwargs = {"model": model}
    kwargs.update(config.get_openai_credentials(model))

    logger.debug(
        f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}"
        f" with model '{model}'"
        + (f" via Azure deployment '{kwargs['engine']}'" if config.use_azure else "")
    )

    embeddings = embedding_provider.create_embedding(
        input,
        **kwargs,
    ).data

    if not multiple:
        return embeddings[0]["embedding"]

    embeddings = sorted(embeddings, key=lambda x: x["index"])
    return [d["embedding"] for d in embeddings]
        return (
            await embedding_provider.create_embedding(
                text=input,
                model_name=model,
                embedding_parser=lambda e: e,
            )
        ).embedding
    else:
        embeddings = []
        for text in input:
            result = await embedding_provider.create_embedding(
                text=text,
                model_name=model,
                embedding_parser=lambda e: e,
            )
            embeddings.append(result.embedding)
        return embeddings


def _get_embedding_with_plugin(text: str, config: Config) -> Embedding:
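As a usage note (illustrative, not part of the diff): callers of get_embedding now pass the embedding provider explicitly and await the call, rather than building OpenAI request kwargs from the config. A minimal sketch, assuming the config and provider come from the caller; embed_query is a hypothetical helper name:

from autogpt.config import Config
from autogpt.core.resource.model_providers import EmbeddingModelProvider
from autogpt.memory.vector.utils import Embedding, get_embedding


async def embed_query(
    query: str,
    config: Config,
    embedding_provider: EmbeddingModelProvider,
) -> Embedding:
    # Single-string input: get_embedding awaits the provider and returns one Embedding.
    return await get_embedding(query, config, embedding_provider)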
@@ -26,7 +26,7 @@ def mock_MemoryItem_from_text(
    mocker: MockerFixture, mock_embedding: Embedding, config: Config
):
    mocker.patch.object(
        file_ops.MemoryItem,
        file_ops.MemoryItemFactory,
        "from_text",
        new=lambda content, source_type, config, metadata: MemoryItem(
            raw_content=content,