Validate skill tree so the UI never breaks (#5306)

Validate skill tree to prevent it from breaking the UI

Signed-off-by: Merwane Hamadi <merwanehamadi@gmail.com>
This commit is contained in:
merwanehamadi
2023-09-22 17:32:05 -07:00
committed by GitHub
parent ecc8d9430c
commit fa9fc18e22
7 changed files with 321 additions and 5 deletions

View File

@@ -1,6 +1,7 @@
{
"category": [
"retrieval"
"retrieval",
"general"
],
"cutoff": 60,
"dependencies": [

View File

@@ -283,23 +283,27 @@ def graph_interactive_network(
# Sync with the flutter UI
# this literally only works in the AutoGPT repo, but this part of the code is not reached if BUILD_SKILL_TREE is false
write_pretty_json(graph_data, flutter_app_path / "tree_structure.json")
validate_skill_tree(graph_data, "")
import json
# Extract node IDs with category "coding"
coding_tree = extract_subgraph_based_on_category(graph_data.copy(), "coding")
validate_skill_tree(coding_tree, "coding")
write_pretty_json(
coding_tree,
flutter_app_path / "coding_tree_structure.json",
)
data_tree = extract_subgraph_based_on_category(graph_data.copy(), "data")
# validate_skill_tree(data_tree, "data")
write_pretty_json(
data_tree,
flutter_app_path / "data_tree_structure.json",
)
general_tree = extract_subgraph_based_on_category(graph_data.copy(), "general")
validate_skill_tree(general_tree, "general")
write_pretty_json(
general_tree,
flutter_app_path / "general_tree_structure.json",
@@ -308,6 +312,7 @@ def graph_interactive_network(
scrape_synthesize_tree = extract_subgraph_based_on_category(
graph_data.copy(), "scrape_synthesize"
)
validate_skill_tree(scrape_synthesize_tree, "scrape_synthesize")
write_pretty_json(
scrape_synthesize_tree,
flutter_app_path / "scrape_synthesize_tree_structure.json",
@@ -360,3 +365,78 @@ def extract_subgraph_based_on_category(graph, category):
reverse_dfs(node_id)
return subgraph
def is_circular(graph):
def dfs(node, visited, stack, parent_map):
visited.add(node)
stack.add(node)
for edge in graph["edges"]:
if edge["from"] == node:
if edge["to"] in stack:
# Detected a cycle
cycle_path = []
current = node
while current != edge["to"]:
cycle_path.append(current)
current = parent_map.get(current)
cycle_path.append(edge["to"])
cycle_path.append(node)
return cycle_path[::-1]
elif edge["to"] not in visited:
parent_map[edge["to"]] = node
cycle_path = dfs(edge["to"], visited, stack, parent_map)
if cycle_path:
return cycle_path
stack.remove(node)
return None
visited = set()
stack = set()
parent_map = {}
for node in graph["nodes"]:
node_id = node["id"]
if node_id not in visited:
cycle_path = dfs(node_id, visited, stack, parent_map)
if cycle_path:
return cycle_path
return None
def get_roots(graph):
"""
Return the roots of a graph. Roots are nodes with no incoming edges.
"""
# Create a set of all node IDs
all_nodes = {node["id"] for node in graph["nodes"]}
# Create a set of nodes with incoming edges
nodes_with_incoming_edges = {edge["to"] for edge in graph["edges"]}
# Roots are nodes that have no incoming edges
roots = all_nodes - nodes_with_incoming_edges
return list(roots)
def validate_skill_tree(graph, skill_tree_name):
"""
Validate if a given graph represents a valid skill tree and raise appropriate exceptions if not.
:param graph: A dictionary representing the graph with 'nodes' and 'edges'.
:raises: ValueError with a description of the invalidity.
"""
# Check for circularity
cycle_path = is_circular(graph)
if cycle_path:
cycle_str = " -> ".join(cycle_path)
raise ValueError(
f"{skill_tree_name} skill tree is circular! Circular path detected: {cycle_str}."
)
# Check for multiple roots
roots = get_roots(graph)
if len(roots) > 1:
raise ValueError(f"{skill_tree_name} skill tree has multiple roots: {roots}.")
elif not roots:
raise ValueError(f"{skill_tree_name} skill tree has no roots.")

View File

@@ -411,7 +411,8 @@
"color": "grey",
"data": {
"category": [
"retrieval"
"retrieval",
"general"
],
"cutoff": 60,
"dependencies": [

View File

@@ -0,0 +1,57 @@
from agbenchmark.utils.dependencies.graphs import get_roots
def test_get_roots():
graph = {
"nodes": [
{"id": "A", "data": {"category": []}},
{"id": "B", "data": {"category": []}},
{"id": "C", "data": {"category": []}},
{"id": "D", "data": {"category": []}},
],
"edges": [
{"from": "A", "to": "B"},
{"from": "B", "to": "C"},
],
}
result = get_roots(graph)
assert set(result) == {
"A",
"D",
}, f"Expected roots to be 'A' and 'D', but got {result}"
def test_no_roots():
fully_connected_graph = {
"nodes": [
{"id": "A", "data": {"category": []}},
{"id": "B", "data": {"category": []}},
{"id": "C", "data": {"category": []}},
],
"edges": [
{"from": "A", "to": "B"},
{"from": "B", "to": "C"},
{"from": "C", "to": "A"},
],
}
result = get_roots(fully_connected_graph)
assert not result, "Expected no roots, but found some"
# def test_no_rcoots():
# fully_connected_graph = {
# "nodes": [
# {"id": "A", "data": {"category": []}},
# {"id": "B", "data": {"category": []}},
# {"id": "C", "data": {"category": []}},
# ],
# "edges": [
# {"from": "A", "to": "B"},
# {"from": "D", "to": "C"},
# ],
# }
#
# result = get_roots(fully_connected_graph)
# assert set(result) == {"A"}, f"Expected roots to be 'A', but got {result}"

View File

@@ -0,0 +1,47 @@
from agbenchmark.utils.dependencies.graphs import is_circular
def test_is_circular():
cyclic_graph = {
"nodes": [
{"id": "A", "data": {"category": []}},
{"id": "B", "data": {"category": []}},
{"id": "C", "data": {"category": []}},
{"id": "D", "data": {"category": []}}, # New node
],
"edges": [
{"from": "A", "to": "B"},
{"from": "B", "to": "C"},
{"from": "C", "to": "D"},
{"from": "D", "to": "A"}, # This edge creates a cycle
],
}
result = is_circular(cyclic_graph)
assert result is not None, "Expected a cycle, but none was detected"
assert all(
(
(result[i], result[i + 1])
in [(x["from"], x["to"]) for x in cyclic_graph["edges"]]
)
for i in range(len(result) - 1)
), "The detected cycle path is not part of the graph's edges"
def test_is_not_circular():
acyclic_graph = {
"nodes": [
{"id": "A", "data": {"category": []}},
{"id": "B", "data": {"category": []}},
{"id": "C", "data": {"category": []}},
{"id": "D", "data": {"category": []}}, # New node
],
"edges": [
{"from": "A", "to": "B"},
{"from": "B", "to": "C"},
{"from": "C", "to": "D"},
# No back edge from D to any node, so it remains acyclic
],
}
assert is_circular(acyclic_graph) is None, "Detected a cycle in an acyclic graph"

View File

@@ -1,4 +1,133 @@
{
"edges": [],
"nodes": []
"edges": [
{
"arrows": "to",
"from": "agbenchmark/generate_test.py::TestSearch::test_method[challenge_data0]",
"id": "agbenchmark/generate_test.py::TestSearch::test_method[challenge_data0]_to_agbenchmark/generate_test.py::TestBasicRetrieval::test_method[challenge_data0]",
"to": "agbenchmark/generate_test.py::TestBasicRetrieval::test_method[challenge_data0]"
},
{
"arrows": "to",
"from": "agbenchmark/generate_test.py::TestWriteFile::test_method[challenge_data0]",
"id": "agbenchmark/generate_test.py::TestWriteFile::test_method[challenge_data0]_to_agbenchmark/generate_test.py::TestSearch::test_method[challenge_data0]",
"to": "agbenchmark/generate_test.py::TestSearch::test_method[challenge_data0]"
}
],
"nodes": [
{
"color": "grey",
"data": {
"category": [
"retrieval",
"general"
],
"cutoff": 60,
"dependencies": [
"TestSearch"
],
"eval_id": "525001ed-8b45-4405-9e56-ce4423314294",
"ground": {
"answer": "\u00a325.89",
"eval": {
"type": "file"
},
"files": [
".txt"
],
"should_contain": [
"25.89"
],
"should_not_contain": []
},
"info": {
"description": "Specifies specific website to retrieve website from.",
"difficulty": "basic",
"side_effects": []
},
"name": "TestBasicRetrieval",
"task": "Write the price of the book in this url 'http://books.toscrape.com/catalogue/meditations_33/index.html' to a .txt file."
},
"id": "agbenchmark/generate_test.py::TestBasicRetrieval::test_method[challenge_data0]",
"label": "BasicRetrieval",
"shape": "dot"
},
{
"color": "grey",
"data": {
"category": [
"interface"
],
"cutoff": 120,
"dependencies": [
"TestWriteFile"
],
"eval_id": "6390114a-531d-4743-a51b-50ba6ff8da43",
"ground": {
"answer": "This is a Heading\nThis is a paragraph.",
"eval": {
"type": "file"
},
"files": [
".txt"
],
"should_contain": [
"Heading",
"paragraph"
],
"should_not_contain": [
"The",
"the"
]
},
"info": {
"description": "Tests if an llm can search",
"difficulty": "interface",
"side_effects": [
""
]
},
"name": "TestSearch",
"task": "Open 'https://silennaihin.com/random/plain.html' and paste all of the text on the page in a .txt file"
},
"id": "agbenchmark/generate_test.py::TestSearch::test_method[challenge_data0]",
"label": "Search",
"shape": "dot"
},
{
"color": "grey",
"data": {
"category": [
"interface"
],
"cutoff": 60,
"dependencies": [],
"eval_id": "81b64bf9-2b6a-4ac8-bcd2-8bfe36244ac0",
"ground": {
"answer": "The word 'Washington', printed to a .txt file named anything",
"eval": {
"type": "file"
},
"files": [
".txt"
],
"should_contain": [
"Washington"
],
"should_not_contain": []
},
"info": {
"description": "Tests the agents ability to write to a file",
"difficulty": "interface",
"side_effects": [
""
]
},
"name": "TestWriteFile",
"task": "Write the word 'Washington' to a .txt file"
},
"id": "agbenchmark/generate_test.py::TestWriteFile::test_method[challenge_data0]",
"label": "WriteFile",
"shape": "dot"
}
]
}

View File

@@ -411,7 +411,8 @@
"color": "grey",
"data": {
"category": [
"retrieval"
"retrieval",
"general"
],
"cutoff": 60,
"dependencies": [