mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 14:14:26 +01:00
chore: refactor read-write lock on agent (#2225)
Co-authored-by: Alice Hau <ahau@squareup.com>
This commit is contained in:
@@ -33,12 +33,28 @@ CALCULATOR_TOOL = {
|
||||
},
|
||||
}
|
||||
|
||||
# Enable Extension tool definition
|
||||
ENABLE_EXTENSION_TOOL = {
|
||||
"name": "enable_extension",
|
||||
"description": "Enable extensions to help complete tasks. Enable an extension by providing the extension name.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"required": ["extension_name"],
|
||||
"properties": {
|
||||
"extension_name": {
|
||||
"type": "string",
|
||||
"description": "The name of the extension to enable",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Frontend extension configuration
|
||||
FRONTEND_CONFIG = {
|
||||
"name": "pythonclient",
|
||||
"type": "frontend",
|
||||
"tools": [CALCULATOR_TOOL],
|
||||
"instructions": "A calculator extension that can perform basic arithmetic operations.",
|
||||
"tools": [CALCULATOR_TOOL, ENABLE_EXTENSION_TOOL],
|
||||
"instructions": "A calculator extension that can perform basic arithmetic operations. Use enable extension tool to add extesions such as fetch, pdf reader, etc.",
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +63,7 @@ async def setup_agent() -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
# First create the agent
|
||||
response = await client.post(
|
||||
f"{GOOSE_URL}/agent",
|
||||
f"{GOOSE_URL}/agent/update_provider",
|
||||
json={"provider": "databricks", "model": "goose"},
|
||||
headers={"X-Secret-Key": SECRET_KEY},
|
||||
)
|
||||
@@ -101,6 +117,55 @@ def execute_calculator(args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
}
|
||||
]
|
||||
|
||||
def get_tools() -> Dict[str, Any]:
|
||||
with httpx.Client() as client:
|
||||
response = client.get(
|
||||
f"{GOOSE_URL}/agent/tools",
|
||||
headers={"X-Secret-Key": SECRET_KEY},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def execute_enable_extension(args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Execute the enable_extension tool.
|
||||
This function fetches available extensions, finds the one with the provided extension_name,
|
||||
and posts its configuration to the /extensions/add endpoint.
|
||||
"""
|
||||
extension = args
|
||||
extension_name = extension.get("name")
|
||||
|
||||
# Post the extension configuration to enable it
|
||||
with httpx.Client() as client:
|
||||
payload = {
|
||||
"type": extension.get("type"),
|
||||
"name": extension.get("name"),
|
||||
"cmd": extension.get("cmd"),
|
||||
"args": extension.get("args"),
|
||||
"envs": extension.get("envs", {}),
|
||||
"timeout": extension.get("timeout"),
|
||||
"bundled": extension.get("bundled"),
|
||||
}
|
||||
add_response = client.post(
|
||||
f"{GOOSE_URL}/extensions/add",
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json", "X-Secret-Key": SECRET_KEY},
|
||||
)
|
||||
if add_response.status_code != 200:
|
||||
error_text = add_response.text
|
||||
return [{
|
||||
"type": "text",
|
||||
"text": f"Error: Failed to enable extension: {error_text}",
|
||||
"annotations": None,
|
||||
}]
|
||||
|
||||
return [{
|
||||
"type": "text",
|
||||
"text": f"Successfully enabled extension: {extension_name}",
|
||||
"annotations": None,
|
||||
}]
|
||||
|
||||
|
||||
def submit_tool_result(tool_id: str, result: List[Dict[str, Any]]) -> None:
|
||||
"""Submit the tool execution result back to Goose.
|
||||
@@ -129,7 +194,7 @@ async def chat_loop() -> None:
|
||||
session_id = "test-session"
|
||||
|
||||
# Use a client with a longer timeout for streaming
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
# Get user input
|
||||
user_message = input("\nYou: ")
|
||||
if user_message.lower() in ["exit", "quit"]:
|
||||
@@ -152,7 +217,7 @@ async def chat_loop() -> None:
|
||||
# Process the stream of responses
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{GOOSE_URL}/reply",
|
||||
f"{GOOSE_URL}/reply", # lock
|
||||
json=payload,
|
||||
headers={
|
||||
"X-Secret-Key": SECRET_KEY,
|
||||
@@ -185,9 +250,26 @@ async def chat_loop() -> None:
|
||||
elif content["type"] == "frontendToolRequest":
|
||||
# Execute the tool and submit results
|
||||
tool_call = content["toolCall"]["value"]
|
||||
print(f"Calculator: {tool_call}")
|
||||
# Execute the tool
|
||||
result = execute_calculator(tool_call["arguments"])
|
||||
print(f"\nTool Request: {tool_call}")
|
||||
|
||||
if tool_call['name'] == "calculator":
|
||||
print(f"Calculator: {tool_call}")
|
||||
# Execute the tool
|
||||
result = execute_calculator(tool_call["arguments"])
|
||||
elif tool_call['name'] == "enable_extension":
|
||||
# to trigger this tool, use the instruction "use enable_extension tool with "fetch" extension name"
|
||||
print(f"Enabling fetch extension")
|
||||
result = execute_enable_extension(args={
|
||||
"type": "stdio",
|
||||
"name": "fetch",
|
||||
"cmd": "uvx",
|
||||
"args": ["mcp-server-fetch"],
|
||||
"timeout": 300,
|
||||
"bundled": False
|
||||
})
|
||||
listed_tools = get_tools()
|
||||
print(f"\nTools after enabling extension: {listed_tools}")
|
||||
|
||||
|
||||
# Submit the result
|
||||
submit_tool_result(content["id"], result)
|
||||
|
||||
Reference in New Issue
Block a user