Auto-GPT/benchmark/agbenchmark/utils/dependencies/graphs.py
Reinier van der Leer 25cc6ad6ae AGBenchmark codebase clean-up (#6650)
* refactor(benchmark): Deduplicate configuration loading logic

   - Move the configuration loading logic to a separate `load_agbenchmark_config` function in the `agbenchmark/config.py` module.
   - Replace the duplicate loading logic in `conftest.py`, `generate_test.py`, `ReportManager.py`, `reports.py`, and `__main__.py` with calls to the `load_agbenchmark_config` function (a rough sketch follows).
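
A deduplicated loader of this kind is usually a small helper. A minimal sketch, assuming the config lives in `agbenchmark_config/config.json` under the working directory; the real function in `agbenchmark/config.py` may differ in file location and return type:

```python
# Hypothetical sketch of a shared config loader; the file location and the
# plain-dict return type are assumptions, not the actual implementation.
import json
from pathlib import Path
from typing import Optional


def load_agbenchmark_config(base_dir: Optional[Path] = None) -> dict:
    """Load the benchmark configuration from agbenchmark_config/config.json."""
    base_dir = base_dir or Path.cwd()
    config_file = base_dir / "agbenchmark_config" / "config.json"
    return json.loads(config_file.read_text())
```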

* fix(benchmark): Fix type errors, linting errors, and clean up CLI validation in __main__.py

   - Fixed type errors and linting errors in `__main__.py`
   - Improved the readability of CLI argument validation by introducing a separate function for it

* refactor(benchmark): Lint and typefix app.py

   - Rearranged and cleaned up import statements
   - Fixed type errors caused by improper use of `psutil` objects
   - Simplified a number of `os.path` usages by converting to `pathlib`
   - Use `Task` and `TaskRequestBody` classes from `agent_protocol_client` instead of `.schema`

* refactor(benchmark): Replace `.agent_protocol_client` with `agent-protocol-client`, clean up schema.py

   - Remove `agbenchmark.agent_protocol_client` (an offline copy of `agent-protocol-client`).
      - Add `agent-protocol-client` as a dependency and change imports to `agent_protocol_client`.
   - Fix type annotation on `agent_api_interface.py::upload_artifacts` (`ApiClient` -> `AgentApi`).
   - Remove all unused types from schema.py (= most of them).

* refactor(benchmark): Use pathlib in agent_interface.py and agent_api_interface.py

* refactor(benchmark): Improve typing, response validation, and readability in app.py

   - Simplified response generation by leveraging type checking and conversion by FastAPI.
   - Introduced use of `HTTPException` for error responses.
   - Improved naming, formatting, and typing in `app.py::create_evaluation`.
   - Updated the docstring on `app.py::create_agent_task`.
   - Fixed return type annotations of `create_single_test` and `create_challenge` in generate_test.py.
   - Added default values to optional attributes on models in report_types_v2.py.
   - Removed unused imports in `generate_test.py`

* refactor(benchmark): Clean up logging and print statements

   - Introduced use of the `logging` library for unified logging and better readability.
   - Converted most print statements to use `logger.debug`, `logger.warning`, and `logger.error`.
   - Improved descriptiveness of log statements.
   - Removed unnecessary print statements.
   - Added log statements to unspecific and non-verbose `except` blocks.
   - Added `--debug` flag, which sets the log level to `DEBUG` and enables a more comprehensive log format.
   - Added `.utils.logging` module with `configure_logging` function to easily configure the logging library (sketched after this list).
   - Converted raw escape sequences in `.utils.challenge` to use `colorama`.
   - Renamed `generate_test.py::generate_tests` to `load_challenges`.
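
A minimal sketch of what the `configure_logging` helper marked above could look like; the format strings and signature are assumptions:

```python
# Sketch of a configure_logging helper; format strings are illustrative.
import logging


def configure_logging(level: int = logging.INFO) -> None:
    # The --debug flag would pass logging.DEBUG, enabling the verbose format.
    fmt = (
        "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s"
        if level == logging.DEBUG
        else "%(levelname)s %(message)s"
    )
    logging.basicConfig(level=level, format=fmt)
```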

* refactor(benchmark): Remove unused server.py and agent_interface.py::run_agent

   - Remove unused server.py file
   - Remove unused run_agent function from agent_interface.py

* refactor(benchmark): Clean up conftest.py

   - Fix and add type annotations
   - Rewrite docstrings
   - Disable or remove unused code
   - Fix definition of arguments and their types in `pytest_addoption`

* refactor(benchmark): Clean up generate_test.py file

   - Refactored the `create_single_test` function for clarity and readability
      - Removed unused variables
      - Made creation of `Challenge` subclasses more straightforward
      - Made bare `except` more specific
   - Renamed `Challenge.setup_challenge` method to `run_challenge`
   - Updated type hints and annotations
   - Made minor code/readability improvements in `load_challenges`
   - Added a helper function `_add_challenge_to_module` for attaching a Challenge class to the current module

* fix(benchmark): Fix and add type annotations in execute_sub_process.py

* refactor(benchmark): Simplify const determination in agent_interface.py

   - Simplify the logic that determines the value of `HELICONE_GRAPHQL_LOGS`
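
Such a flag is commonly derived from the environment in a single expression; a plausible sketch (the truthiness convention is an assumption):

```python
import os

# Sketch: interpret the env var as a boolean flag; exact semantics assumed.
HELICONE_GRAPHQL_LOGS = os.getenv("HELICONE_GRAPHQL_LOGS", "").lower() == "true"
```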

* fix(benchmark): Register category markers to prevent warnings

   - Use the `pytest_configure` hook to register the known challenge categories as markers. Otherwise, Pytest will raise "unknown marker" warnings at runtime.
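
Marker registration through `pytest_configure` uses the standard pytest API, roughly as follows; the category names shown are the AGBenchmark skill-tree categories and serve only as an illustration:

```python
# In conftest.py: register challenge categories as markers so pytest
# does not emit "unknown marker" warnings. Category list is illustrative.
def pytest_configure(config) -> None:
    for category in ("coding", "data", "general", "scrape_synthesize"):
        config.addinivalue_line("markers", f"{category}: {category} challenges")
```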

* refactor(benchmark/challenges): Fix indentation in 4_revenue_retrieval_2/data.json

* refactor(benchmark): Update agent_api_interface.py

   - Add type annotations to `copy_agent_artifacts_into_temp_folder` function
   - Add note about broken endpoint in the `agent_protocol_client` library
   - Remove unused variable in `run_api_agent` function
   - Improve readability and resolve linting error

* feat(benchmark): Improve and centralize pathfinding

   - Search the path hierarchy for an applicable `agbenchmark_config`, rather than assuming it's in the current folder (see the sketch after this list).
   - Create `agbenchmark.utils.path_manager` with `AGBenchmarkPathManager` and export a `PATH_MANAGER` constant.
   - Replace path constants defined in `__main__.py` with usages of `PATH_MANAGER`.
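
A minimal sketch of the path-hierarchy search using `pathlib`; the helper name and error handling are assumptions:

```python
from pathlib import Path
from typing import Optional


def find_agbenchmark_config_dir(start: Optional[Path] = None) -> Path:
    """Walk up from `start` until a folder containing agbenchmark_config is found."""
    start = start or Path.cwd()
    for candidate in (start, *start.parents):
        config_dir = candidate / "agbenchmark_config"
        if config_dir.is_dir():
            return config_dir
    raise FileNotFoundError("No agbenchmark_config directory found in path hierarchy")
```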

* feat(benchmark/cli): Clean up and improve CLI

   - Updated commands, options, and their descriptions to be more intuitive and consistent
   - Moved slow imports into the entrypoints that use them to speed up application startup
   - Fixed type hints to match output types of Click options
   - Hid deprecated `agbenchmark start` command
   - Refactored code to improve readability and maintainability
   - Moved main entrypoint into `run` subcommand
   - Fixed `version` and `serve` subcommands
   - Added `click-default-group` package to allow using `run` implicitly (for backwards compatibility; sketched after this list)
   - Renamed `--no_dep` to `--no-dep` for consistency
   - Fixed string formatting issues in log statements
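
The `click-default-group` mechanism lets `agbenchmark <args>` behave like `agbenchmark run <args>`; a rough sketch, with illustrative options:

```python
# Sketch of a CLI group with an implicit default subcommand.
import click
from click_default_group import DefaultGroup


@click.group(cls=DefaultGroup, default="run", default_if_no_args=True)
def cli() -> None:
    """AGBenchmark command-line interface."""


@cli.command()
@click.option("--no-dep", is_flag=True, help="Run without checking dependencies.")
def run(no_dep: bool) -> None:
    click.echo(f"Running benchmark (no_dep={no_dep})")
```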

* refactor(benchmark/config): Move AgentBenchmarkConfig and related functions to config.py

   - Move the `AgentBenchmarkConfig` class from `utils/data_types.py` to `config.py`.
   - Extract the `calculate_info_test_path` function from `utils/data_types.py` and move it to `config.py` as a private helper function `_calculate_info_test_path`.
   - Move `load_agent_benchmark_config()` to `AgentBenchmarkConfig.load()`.
   - Changed simple getter methods on `AgentBenchmarkConfig` to calculated properties.
   - Update all code references according to the changes mentioned above.

* refactor(benchmark): Fix ReportManager init parameter types and use pathlib

   - Fix the type annotation of the `benchmark_start_time` parameter in `ReportManager.__init__`, which was mistyped as `str` instead of `datetime`.
   - Change the type of the `filename` parameter in the `ReportManager.__init__` method from `str` to `Path`.
   - Rename `self.filename` to `self.report_file` in `ReportManager`.
   - Change the way the report file is created, opened and saved to use the `Path` object.

* refactor(benchmark): Improve typing surrounding ChallengeData and clean up its implementation

   - Use `ChallengeData` objects instead of untyped `dict`s in app.py, generate_test.py, and reports.py.
   - Remove unnecessary methods `serialize`, `get_data`, `get_json_from_path`, `deserialize` from the `ChallengeData` class.
   - Remove unused methods `challenge_from_datum` and `challenge_from_test_data` from the `ChallengeData` class.
   - Update function signatures and annotations of `create_challenge` and `generate_single_test` functions in generate_test.py.
   - Add types to function signatures of `generate_single_call_report` and `finalize_reports` in reports.py.
   - Remove unnecessary `challenge_data` parameter (in generate_test.py) and fixture (in conftest.py).

* refactor(benchmark): Clean up generate_test.py, conftest.py and __main__.py

   - Cleaned up generate_test.py and conftest.py
      - Consolidated challenge creation logic in the `Challenge` class itself, most notably the new `Challenge.from_challenge_spec` method.
      - Moved challenge selection logic from generate_test.py to the `pytest_collection_modifyitems` hook in conftest.py (a rough sketch follows this list).
   - Converted methods in the `Challenge` class to class methods where appropriate.
   - Improved argument handling in the `run_benchmark` function in `__main__.py`.
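
Challenge selection in `pytest_collection_modifyitems` can be expressed with standard pytest hooks; a rough sketch, assuming a `--category` CLI option and category markers on the collected tests:

```python
# Sketch of marker-based selection in conftest.py; the option name is assumed.
def pytest_collection_modifyitems(config, items) -> None:
    selected = set(config.getoption("--category") or [])
    if not selected:
        return  # no category filter requested; keep everything
    kept, deselected = [], []
    for item in items:
        markers = {marker.name for marker in item.iter_markers()}
        (kept if markers & selected else deselected).append(item)
    items[:] = kept
    config.hook.pytest_deselected(items=deselected)
```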

* refactor(benchmark/config): Merge AGBenchmarkPathManager into AgentBenchmarkConfig and reduce fragmented/global state

   - Merge the functionality of `AGBenchmarkPathManager` into `AgentBenchmarkConfig` to consolidate the configuration management.
   - Remove the `.path_manager` module containing `AGBenchmarkPathManager`.
   - Pass the `AgentBenchmarkConfig` and its attributes through function arguments to reduce global state and improve code clarity.

* feat(benchmark/serve): Configurable port for `serve` subcommand

   - Added `--port` option to `serve` subcommand to allow for specifying the port to run the API on.
   - If no `--port` option is provided, the port will default to the value specified in the `PORT` environment variable, or 8080 if not set.
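
A sketch of the resulting `serve` entrypoint; the option wiring is illustrative:

```python
import os
from typing import Optional

import click


@click.command()
@click.option("--port", type=int, help="Port to run the API on.")
def serve(port: Optional[int]) -> None:
    port = port or int(os.getenv("PORT", 8080))
    click.echo(f"Serving AGBenchmark API on port {port}")
```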

* feat(benchmark/cli): Add `config` subcommand

   - Added a new subcommand `config` to the AGBenchmark CLI, to display information about the present AGBenchmark config.

* fix(benchmark): Gracefully handle incompatible challenge spec files in app.py

   - Added a check to skip deprecated challenges
   - Added logging to allow debugging of the loading process
   - Added handling of validation errors when parsing challenge spec files
   - Added missing `spec_file` attribute to `ChallengeData`
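
The tolerant spec loading described above might look roughly like this; `ChallengeSpec` is a stand-in for the real `ChallengeData` model, and the deprecation convention is an assumption:

```python
import json
import logging
from pathlib import Path
from typing import List, Optional

from pydantic import BaseModel, ValidationError

logger = logging.getLogger(__name__)


class ChallengeSpec(BaseModel):  # stand-in for the real ChallengeData model
    name: str
    category: List[str] = []
    spec_file: Optional[Path] = None


def load_challenge_spec(spec_file: Path) -> Optional[ChallengeSpec]:
    try:
        challenge = ChallengeSpec(**json.loads(spec_file.read_text()))
    except ValidationError as e:
        logger.warning(f"Spec file {spec_file} is invalid, skipping: {e}")
        return None
    if "deprecated" in challenge.category:  # deprecation convention assumed
        logger.debug(f"Skipping deprecated challenge {challenge.name}")
        return None
    challenge.spec_file = spec_file
    return challenge
```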

* refactor(benchmark): Move `run_benchmark` entrypoint to main.py, use it in `/reports` endpoint

   - Move `run_benchmark` and `validate_args` from __main__.py to main.py
   - Replace agbenchmark subprocess in `app.py:run_single_test` with `run_benchmark`
   - Move `get_unique_categories` from __main__.py to challenges/__init__.py
   - Move `OPTIONAL_CATEGORIES` from __main__.py to challenge.py
   - Reduce operations on updates.json (including `initialize_updates_file`) outside of API

* refactor(benchmark): Remove unused `/updates` endpoint and all related code

   - Remove `updates_json_file` attribute from `AgentBenchmarkConfig`
   - Remove `get_updates` and `_initialize_updates_file` in app.py
   - Remove `append_updates_file` and `create_update_json` functions in agent_api_interface.py
   - Remove call to `append_updates_file` in challenge.py

* refactor(benchmark/config): Clean up and update docstrings on `AgentBenchmarkConfig`

   - Add and update docstrings
   - Change base class from `BaseModel` to `BaseSettings`, allow extras for backwards compatibility (sketched after this list)
   - Make naming of path attributes on `AgentBenchmarkConfig` more consistent
   - Remove unused `agent_home_directory` attribute
   - Remove unused `workspace` attribute
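
A compact sketch combining the `BaseSettings` change marked above with the calculated-property pattern from the earlier `AgentBenchmarkConfig` refactor; attribute names are illustrative and the config style is pydantic v1:

```python
from pathlib import Path

from pydantic import BaseSettings  # pydantic v1-style import


class AgentBenchmarkConfig(BaseSettings):
    """Sketch only; the real attribute names and defaults may differ."""

    agbenchmark_config_dir: Path

    class Config:
        extra = "allow"  # tolerate unknown keys for backwards compatibility

    @property
    def reports_folder(self) -> Path:
        # calculated property replacing a simple getter method
        return self.agbenchmark_config_dir / "reports"
```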

* fix(benchmark): Restore mechanism to select (optional) categories in agent benchmark config

* fix(benchmark): Update agent-protocol-client to v1.1.0

   - Fixes issue with fetching task artifact listings
2024-01-02 22:23:09 +01:00


import json
import logging
import math
from pathlib import Path
from typing import Any, Dict, List, Tuple

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from pyvis.network import Network

from agbenchmark.generate_test import DATA_CATEGORY
from agbenchmark.utils.utils import write_pretty_json

logger = logging.getLogger(__name__)


def bezier_curve(
    src: np.ndarray, ctrl: List[float], dst: np.ndarray
) -> List[np.ndarray]:
    """
    Generate Bézier curve points.

    Args:
    - src (np.ndarray): The source point.
    - ctrl (List[float]): The control point.
    - dst (np.ndarray): The destination point.

    Returns:
    - List[np.ndarray]: The Bézier curve points.
    """
    curve = []
    for t in np.linspace(0, 1, num=100):
        curve_point = (
            np.outer((1 - t) ** 2, src)
            + 2 * np.outer((1 - t) * t, ctrl)
            + np.outer(t**2, dst)
        )
        curve.append(curve_point[0])
    return curve


def curved_edges(
    G: nx.Graph, pos: Dict[Any, Tuple[float, float]], dist: float = 0.2
) -> None:
    """
    Draw curved edges for nodes on the same level.

    Args:
    - G (Any): The graph object.
    - pos (Dict[Any, Tuple[float, float]]): Dictionary with node positions.
    - dist (float, optional): Distance for curvature. Defaults to 0.2.

    Returns:
    - None
    """
    ax = plt.gca()
    for u, v, data in G.edges(data=True):
        src = np.array(pos[u])
        dst = np.array(pos[v])

        same_level = abs(src[1] - dst[1]) < 0.01

        if same_level:
            control = [(src[0] + dst[0]) / 2, src[1] + dist]
            curve = bezier_curve(src, control, dst)
            arrow = patches.FancyArrowPatch(
                posA=curve[0],  # type: ignore
                posB=curve[-1],  # type: ignore
                connectionstyle=f"arc3,rad=0.2",
                color="gray",
                arrowstyle="-|>",
                mutation_scale=15.0,
                lw=1,
                shrinkA=10,
                shrinkB=10,
            )
            ax.add_patch(arrow)
        else:
            ax.annotate(
                "",
                xy=dst,
                xytext=src,
                arrowprops=dict(
                    arrowstyle="-|>", color="gray", lw=1, shrinkA=10, shrinkB=10
                ),
            )


def tree_layout(graph: nx.DiGraph, root_node: Any) -> Dict[Any, Tuple[float, float]]:
    """Compute positions as a tree layout centered on the root
    with alternating vertical shifts."""
    bfs_tree = nx.bfs_tree(graph, source=root_node)
    levels = {
        node: depth
        for node, depth in nx.single_source_shortest_path_length(
            bfs_tree, root_node
        ).items()
    }

    pos = {}
    max_depth = max(levels.values())
    level_positions = {i: 0 for i in range(max_depth + 1)}  # type: ignore

    # Count the number of nodes per level to compute the width
    level_count: Any = {}
    for node, level in levels.items():
        level_count[level] = level_count.get(level, 0) + 1

    vertical_offset = (
        0.07  # The amount of vertical shift per node within the same level
    )

    # Assign positions
    for node, level in sorted(levels.items(), key=lambda x: x[1]):
        total_nodes_in_level = level_count[level]
        horizontal_spacing = 1.0 / (total_nodes_in_level + 1)
        pos_x = (
            0.5
            - (total_nodes_in_level - 1) * horizontal_spacing / 2
            + level_positions[level] * horizontal_spacing
        )

        # Alternately shift nodes up and down within the same level
        pos_y = (
            -level
            + (level_positions[level] % 2) * vertical_offset
            - ((level_positions[level] + 1) % 2) * vertical_offset
        )

        pos[node] = (pos_x, pos_y)
        level_positions[level] += 1

    return pos


def graph_spring_layout(
    dag: nx.DiGraph, labels: Dict[Any, str], tree: bool = True
) -> None:
    num_nodes = len(dag.nodes())
    # Setting up the figure and axis
    fig, ax = plt.subplots()
    ax.axis("off")  # Turn off the axis

    base = 3.0

    if num_nodes > 10:
        base /= 1 + math.log(num_nodes)
        font_size = base * 10

    font_size = max(10, base * 10)
    node_size = max(300, base * 1000)

    if tree:
        root_node = [node for node, degree in dag.in_degree() if degree == 0][0]
        pos = tree_layout(dag, root_node)
    else:
        # Adjust k for the spring layout based on node count
        k_value = 3 / math.sqrt(num_nodes)

        pos = nx.spring_layout(dag, k=k_value, iterations=50)

    # Draw nodes and labels
    nx.draw_networkx_nodes(dag, pos, node_color="skyblue", node_size=int(node_size))
    nx.draw_networkx_labels(dag, pos, labels=labels, font_size=int(font_size))

    # Draw curved edges
    curved_edges(dag, pos)  # type: ignore

    plt.tight_layout()
    plt.show()


def rgb_to_hex(rgb: Tuple[float, float, float]) -> str:
    return "#{:02x}{:02x}{:02x}".format(
        int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255)
    )


def get_category_colors(categories: Dict[Any, str]) -> Dict[str, str]:
    unique_categories = set(categories.values())
    colormap = plt.cm.get_cmap("tab10", len(unique_categories))  # type: ignore
    return {
        category: rgb_to_hex(colormap(i)[:3])
        for i, category in enumerate(unique_categories)
    }


def graph_interactive_network(
    dag: nx.DiGraph,
    labels: Dict[Any, Dict[str, Any]],
    html_graph_path: str = "",
) -> None:
    nt = Network(notebook=True, width="100%", height="800px", directed=True)

    category_colors = get_category_colors(DATA_CATEGORY)

    # Add nodes and edges to the pyvis network
    for node, json_data in labels.items():
        label = json_data.get("name", "")
        # remove the first 4 letters of label
        label_without_test = label[4:]
        node_id_str = node.nodeid

        # Get the category for this label
        category = DATA_CATEGORY.get(
            label, "unknown"
        )  # Default to 'unknown' if label not found

        # Get the color for this category
        color = category_colors.get(category, "grey")

        nt.add_node(
            node_id_str,
            label=label_without_test,
            color=color,
            data=json_data,
        )

    # Add edges to the pyvis network
    for edge in dag.edges():
        source_id_str = edge[0].nodeid
        target_id_str = edge[1].nodeid
        edge_id_str = (
            f"{source_id_str}_to_{target_id_str}"  # Construct a unique edge id
        )
        if not (source_id_str in nt.get_nodes() and target_id_str in nt.get_nodes()):
            logger.warning(
                f"Skipping edge {source_id_str} -> {target_id_str} due to missing nodes"
            )
            continue
        nt.add_edge(source_id_str, target_id_str, id=edge_id_str)

    # Configure physics for hierarchical layout
    hierarchical_options = {
        "enabled": True,
        "levelSeparation": 200,  # Increased vertical spacing between levels
        "nodeSpacing": 250,  # Increased spacing between nodes on the same level
        "treeSpacing": 250,  # Increased spacing between different trees (for forest)
        "blockShifting": True,
        "edgeMinimization": True,
        "parentCentralization": True,
        "direction": "UD",
        "sortMethod": "directed",
    }
    physics_options = {
        "stabilization": {
            "enabled": True,
            "iterations": 1000,  # Default is often around 100
        },
        "hierarchicalRepulsion": {
            "centralGravity": 0.0,
            "springLength": 200,  # Increased edge length
            "springConstant": 0.01,
            "nodeDistance": 250,  # Increased minimum distance between nodes
            "damping": 0.09,
        },
        "solver": "hierarchicalRepulsion",
        "timestep": 0.5,
    }
    nt.options = {
        "nodes": {
            "font": {
                "size": 20,  # Increased font size for labels
                "color": "black",  # Set a readable font color
            },
            "shapeProperties": {"useBorderWithImage": True},
        },
        "edges": {
            "length": 250,  # Increased edge length
        },
        "physics": physics_options,
        "layout": {"hierarchical": hierarchical_options},
    }

    # Serialize the graph to JSON and save in appropriate locations
    graph_data = {"nodes": nt.nodes, "edges": nt.edges}
    logger.debug(f"Generated graph data:\n{json.dumps(graph_data, indent=4)}")

    # FIXME: use more reliable method to find the right location for these files.
    #   This will fail in all cases except if run from the root of our repo.
    home_path = Path.cwd()
    write_pretty_json(graph_data, home_path / "frontend" / "public" / "graph.json")

    flutter_app_path = home_path.parent / "frontend" / "assets"

    # Optionally, save to a file
    # Sync with the flutter UI
    # this literally only works in the AutoGPT repo, but this part of the code is not reached if BUILD_SKILL_TREE is false
    write_pretty_json(graph_data, flutter_app_path / "tree_structure.json")
    validate_skill_tree(graph_data, "")

    # Extract node IDs with category "coding"
    coding_tree = extract_subgraph_based_on_category(graph_data.copy(), "coding")
    validate_skill_tree(coding_tree, "coding")
    write_pretty_json(
        coding_tree,
        flutter_app_path / "coding_tree_structure.json",
    )

    data_tree = extract_subgraph_based_on_category(graph_data.copy(), "data")
    # validate_skill_tree(data_tree, "data")
    write_pretty_json(
        data_tree,
        flutter_app_path / "data_tree_structure.json",
    )

    general_tree = extract_subgraph_based_on_category(graph_data.copy(), "general")
    validate_skill_tree(general_tree, "general")
    write_pretty_json(
        general_tree,
        flutter_app_path / "general_tree_structure.json",
    )

    scrape_synthesize_tree = extract_subgraph_based_on_category(
        graph_data.copy(), "scrape_synthesize"
    )
    validate_skill_tree(scrape_synthesize_tree, "scrape_synthesize")
    write_pretty_json(
        scrape_synthesize_tree,
        flutter_app_path / "scrape_synthesize_tree_structure.json",
    )

    if html_graph_path:
        file_path = str(Path(html_graph_path).resolve())

        nt.write_html(file_path)


def extract_subgraph_based_on_category(graph, category):
    """
    Extracts a subgraph that includes all nodes and edges required to
    reach all nodes with a specified category.

    :param graph: The original graph.
    :param category: The target category.
    :return: Subgraph with nodes and edges required to reach the nodes with the given category.
    """
    subgraph = {"nodes": [], "edges": []}
    visited = set()

    def reverse_dfs(node_id):
        if node_id in visited:
            return
        visited.add(node_id)

        node_data = next(node for node in graph["nodes"] if node["id"] == node_id)

        # Add the node to the subgraph if it's not already present.
        if node_data not in subgraph["nodes"]:
            subgraph["nodes"].append(node_data)

        for edge in graph["edges"]:
            if edge["to"] == node_id:
                if edge not in subgraph["edges"]:
                    subgraph["edges"].append(edge)
                reverse_dfs(edge["from"])

    # Identify nodes with the target category and initiate reverse DFS from them.
    nodes_with_target_category = [
        node["id"] for node in graph["nodes"] if category in node["data"]["category"]
    ]
    for node_id in nodes_with_target_category:
        reverse_dfs(node_id)

    return subgraph


def is_circular(graph):
    def dfs(node, visited, stack, parent_map):
        visited.add(node)
        stack.add(node)
        for edge in graph["edges"]:
            if edge["from"] == node:
                if edge["to"] in stack:
                    # Detected a cycle
                    cycle_path = []
                    current = node
                    while current != edge["to"]:
                        cycle_path.append(current)
                        current = parent_map.get(current)
                    cycle_path.append(edge["to"])
                    cycle_path.append(node)
                    return cycle_path[::-1]
                elif edge["to"] not in visited:
                    parent_map[edge["to"]] = node
                    cycle_path = dfs(edge["to"], visited, stack, parent_map)
                    if cycle_path:
                        return cycle_path
        stack.remove(node)
        return None

    visited = set()
    stack = set()
    parent_map = {}
    for node in graph["nodes"]:
        node_id = node["id"]
        if node_id not in visited:
            cycle_path = dfs(node_id, visited, stack, parent_map)
            if cycle_path:
                return cycle_path
    return None


def get_roots(graph):
    """
    Return the roots of a graph. Roots are nodes with no incoming edges.
    """
    # Create a set of all node IDs
    all_nodes = {node["id"] for node in graph["nodes"]}

    # Create a set of nodes with incoming edges
    nodes_with_incoming_edges = {edge["to"] for edge in graph["edges"]}

    # Roots are nodes that have no incoming edges
    roots = all_nodes - nodes_with_incoming_edges

    return list(roots)


def validate_skill_tree(graph, skill_tree_name):
    """
    Validate if a given graph represents a valid skill tree
    and raise appropriate exceptions if not.

    :param graph: A dictionary representing the graph with 'nodes' and 'edges'.
    :raises: ValueError with a description of the invalidity.
    """
    # Check for circularity
    cycle_path = is_circular(graph)
    if cycle_path:
        cycle_str = " -> ".join(cycle_path)
        raise ValueError(
            f"{skill_tree_name} skill tree is circular! Circular path detected: {cycle_str}."
        )

    # Check for multiple roots
    roots = get_roots(graph)
    if len(roots) > 1:
        raise ValueError(f"{skill_tree_name} skill tree has multiple roots: {roots}.")
    elif not roots:
        raise ValueError(f"{skill_tree_name} skill tree has no roots.")