chore: refactor read-write lock on agent (#2225)

Co-authored-by: Alice Hau <ahau@squareup.com>
This commit is contained in:
Salman Mohammed
2025-04-23 23:46:22 -03:00
committed by GitHub
parent 85e2ee3984
commit 199fa6adbc
24 changed files with 409 additions and 237 deletions

View File

@@ -33,12 +33,28 @@ CALCULATOR_TOOL = {
},
}
# Enable Extension tool definition
ENABLE_EXTENSION_TOOL = {
"name": "enable_extension",
"description": "Enable extensions to help complete tasks. Enable an extension by providing the extension name.",
"inputSchema": {
"type": "object",
"required": ["extension_name"],
"properties": {
"extension_name": {
"type": "string",
"description": "The name of the extension to enable",
},
},
},
}
# Frontend extension configuration
FRONTEND_CONFIG = {
"name": "pythonclient",
"type": "frontend",
"tools": [CALCULATOR_TOOL],
"instructions": "A calculator extension that can perform basic arithmetic operations.",
"tools": [CALCULATOR_TOOL, ENABLE_EXTENSION_TOOL],
"instructions": "A calculator extension that can perform basic arithmetic operations. Use enable extension tool to add extesions such as fetch, pdf reader, etc.",
}
@@ -47,7 +63,7 @@ async def setup_agent() -> None:
async with httpx.AsyncClient() as client:
# First create the agent
response = await client.post(
f"{GOOSE_URL}/agent",
f"{GOOSE_URL}/agent/update_provider",
json={"provider": "databricks", "model": "goose"},
headers={"X-Secret-Key": SECRET_KEY},
)
@@ -101,6 +117,55 @@ def execute_calculator(args: Dict[str, Any]) -> List[Dict[str, Any]]:
}
]
def get_tools() -> Dict[str, Any]:
with httpx.Client() as client:
response = client.get(
f"{GOOSE_URL}/agent/tools",
headers={"X-Secret-Key": SECRET_KEY},
)
response.raise_for_status()
return response.json()
def execute_enable_extension(args: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Execute the enable_extension tool.
This function fetches available extensions, finds the one with the provided extension_name,
and posts its configuration to the /extensions/add endpoint.
"""
extension = args
extension_name = extension.get("name")
# Post the extension configuration to enable it
with httpx.Client() as client:
payload = {
"type": extension.get("type"),
"name": extension.get("name"),
"cmd": extension.get("cmd"),
"args": extension.get("args"),
"envs": extension.get("envs", {}),
"timeout": extension.get("timeout"),
"bundled": extension.get("bundled"),
}
add_response = client.post(
f"{GOOSE_URL}/extensions/add",
json=payload,
headers={"Content-Type": "application/json", "X-Secret-Key": SECRET_KEY},
)
if add_response.status_code != 200:
error_text = add_response.text
return [{
"type": "text",
"text": f"Error: Failed to enable extension: {error_text}",
"annotations": None,
}]
return [{
"type": "text",
"text": f"Successfully enabled extension: {extension_name}",
"annotations": None,
}]
def submit_tool_result(tool_id: str, result: List[Dict[str, Any]]) -> None:
"""Submit the tool execution result back to Goose.
@@ -129,7 +194,7 @@ async def chat_loop() -> None:
session_id = "test-session"
# Use a client with a longer timeout for streaming
async with httpx.AsyncClient(timeout=30.0) as client:
async with httpx.AsyncClient(timeout=60.0) as client:
# Get user input
user_message = input("\nYou: ")
if user_message.lower() in ["exit", "quit"]:
@@ -152,7 +217,7 @@ async def chat_loop() -> None:
# Process the stream of responses
async with client.stream(
"POST",
f"{GOOSE_URL}/reply",
f"{GOOSE_URL}/reply", # lock
json=payload,
headers={
"X-Secret-Key": SECRET_KEY,
@@ -185,9 +250,26 @@ async def chat_loop() -> None:
elif content["type"] == "frontendToolRequest":
# Execute the tool and submit results
tool_call = content["toolCall"]["value"]
print(f"Calculator: {tool_call}")
# Execute the tool
result = execute_calculator(tool_call["arguments"])
print(f"\nTool Request: {tool_call}")
if tool_call['name'] == "calculator":
print(f"Calculator: {tool_call}")
# Execute the tool
result = execute_calculator(tool_call["arguments"])
elif tool_call['name'] == "enable_extension":
# to trigger this tool, use the instruction "use enable_extension tool with "fetch" extension name"
print(f"Enabling fetch extension")
result = execute_enable_extension(args={
"type": "stdio",
"name": "fetch",
"cmd": "uvx",
"args": ["mcp-server-fetch"],
"timeout": 300,
"bundled": False
})
listed_tools = get_tools()
print(f"\nTools after enabling extension: {listed_tools}")
# Submit the result
submit_tool_result(content["id"], result)