Add early abort functions.

This commit is contained in:
BillSchumacher
2023-04-16 23:39:33 -05:00
parent 3715ebc7eb
commit fbd4e06df5
4 changed files with 20 additions and 0 deletions

View File

@@ -180,6 +180,8 @@ class Agent:
result = f"Human feedback: {user_input}" result = f"Human feedback: {user_input}"
else: else:
for plugin in cfg.plugins: for plugin in cfg.plugins:
if not plugin.can_handle_pre_command():
continue
command_name, arguments = plugin.pre_command( command_name, arguments = plugin.pre_command(
command_name, arguments command_name, arguments
) )
@@ -192,6 +194,8 @@ class Agent:
result = f"Command {command_name} returned: " f"{command_result}" result = f"Command {command_name} returned: " f"{command_result}"
for plugin in cfg.plugins: for plugin in cfg.plugins:
if not plugin.can_handle_post_command():
continue
result = plugin.post_command(command_name, result) result = plugin.post_command(command_name, result)
if self.next_action_count > 0: if self.next_action_count > 0:
self.next_action_count -= 1 self.next_action_count -= 1

View File

@@ -31,6 +31,8 @@ class AgentManager(metaclass=Singleton):
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
] ]
for plugin in self.cfg.plugins: for plugin in self.cfg.plugins:
if not plugin.can_handle_pre_instruction():
continue
plugin_messages = plugin.pre_instruction(messages) plugin_messages = plugin.pre_instruction(messages)
if plugin_messages: if plugin_messages:
for plugin_message in plugin_messages: for plugin_message in plugin_messages:
@@ -46,6 +48,8 @@ class AgentManager(metaclass=Singleton):
plugins_reply = "" plugins_reply = ""
for i, plugin in enumerate(self.cfg.plugins): for i, plugin in enumerate(self.cfg.plugins):
if not plugin.can_handle_on_instruction():
continue
plugin_result = plugin.on_instruction(messages) plugin_result = plugin.on_instruction(messages)
if plugin_result: if plugin_result:
sep = "" if not i else "\n" sep = "" if not i else "\n"
@@ -61,6 +65,8 @@ class AgentManager(metaclass=Singleton):
self.agents[key] = (task, messages, model) self.agents[key] = (task, messages, model)
for plugin in self.cfg.plugins: for plugin in self.cfg.plugins:
if not plugin.can_handle_post_instruction():
continue
agent_reply = plugin.post_instruction(agent_reply) agent_reply = plugin.post_instruction(agent_reply)
return key, agent_reply return key, agent_reply
@@ -81,6 +87,8 @@ class AgentManager(metaclass=Singleton):
messages.append({"role": "user", "content": message}) messages.append({"role": "user", "content": message})
for plugin in self.cfg.plugins: for plugin in self.cfg.plugins:
if not plugin.can_handle_pre_instruction():
continue
plugin_messages = plugin.pre_instruction(messages) plugin_messages = plugin.pre_instruction(messages)
if plugin_messages: if plugin_messages:
for plugin_message in plugin_messages: for plugin_message in plugin_messages:
@@ -96,6 +104,8 @@ class AgentManager(metaclass=Singleton):
plugins_reply = agent_reply plugins_reply = agent_reply
for i, plugin in enumerate(self.cfg.plugins): for i, plugin in enumerate(self.cfg.plugins):
if not plugin.can_handle_on_instruction():
continue
plugin_result = plugin.on_instruction(messages) plugin_result = plugin.on_instruction(messages)
if plugin_result: if plugin_result:
sep = "" if not i else "\n" sep = "" if not i else "\n"
@@ -105,6 +115,8 @@ class AgentManager(metaclass=Singleton):
messages.append({"role": "assistant", "content": plugins_reply}) messages.append({"role": "assistant", "content": plugins_reply})
for plugin in self.cfg.plugins: for plugin in self.cfg.plugins:
if not plugin.can_handle_post_instruction():
continue
agent_reply = plugin.post_instruction(agent_reply) agent_reply = plugin.post_instruction(agent_reply)
return agent_reply return agent_reply

View File

@@ -137,6 +137,8 @@ def chat_with_ai(
plugin_count = len(cfg.plugins) plugin_count = len(cfg.plugins)
for i, plugin in enumerate(cfg.plugins): for i, plugin in enumerate(cfg.plugins):
if not plugin.can_handle_on_planning():
continue
plugin_response = plugin.on_planning( plugin_response = plugin.on_planning(
agent.prompt_generator, current_context agent.prompt_generator, current_context
) )

View File

@@ -131,6 +131,8 @@ def create_chat_completion(
raise RuntimeError(f"Failed to get response after {num_retries} retries") raise RuntimeError(f"Failed to get response after {num_retries} retries")
resp = response.choices[0].message["content"] resp = response.choices[0].message["content"]
for plugin in CFG.plugins: for plugin in CFG.plugins:
if not plugin.can_handle_on_response():
continue
resp = plugin.on_response(resp) resp = plugin.on_response(resp)
return resp return resp