You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
256 lines
9.6 KiB
256 lines
9.6 KiB
import asyncio
|
|
import os
|
|
import signal
|
|
import traceback
|
|
from typing import Optional
|
|
|
|
import typer
|
|
|
|
from ...utils import ANSI
|
|
from ._cli_hacks import _async_prompt, _patch_anyio_open_process
|
|
from .agent import Agent
|
|
from .utils import _load_agent_config
|
|
|
|
|
|
app = typer.Typer(
|
|
rich_markup_mode="rich",
|
|
help="A squad of lightweight composable AI applications built on Hugging Face's Inference Client and MCP stack.",
|
|
)
|
|
|
|
run_cli = typer.Typer(
|
|
name="run",
|
|
help="Run the Agent in the CLI",
|
|
invoke_without_command=True,
|
|
)
|
|
app.add_typer(run_cli, name="run")
|
|
|
|
|
|
async def run_agent(
|
|
agent_path: Optional[str],
|
|
) -> None:
|
|
"""
|
|
Tiny Agent loop.
|
|
|
|
Args:
|
|
agent_path (`str`, *optional*):
|
|
Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` or `AGENTS.md` file or a built-in agent stored in a Hugging Face dataset.
|
|
|
|
"""
|
|
_patch_anyio_open_process() # Hacky way to prevent stdio connections to be stopped by Ctrl+C
|
|
|
|
config, prompt = _load_agent_config(agent_path)
|
|
|
|
inputs = config.get("inputs", [])
|
|
servers = config.get("servers", [])
|
|
|
|
abort_event = asyncio.Event()
|
|
exit_event = asyncio.Event()
|
|
first_sigint = True
|
|
|
|
loop = asyncio.get_running_loop()
|
|
original_sigint_handler = signal.getsignal(signal.SIGINT)
|
|
|
|
def _sigint_handler() -> None:
|
|
nonlocal first_sigint
|
|
if first_sigint:
|
|
first_sigint = False
|
|
abort_event.set()
|
|
print(ANSI.red("\nInterrupted. Press Ctrl+C again to quit."), flush=True)
|
|
return
|
|
|
|
print(ANSI.red("\nExiting..."), flush=True)
|
|
exit_event.set()
|
|
|
|
try:
|
|
sigint_registered_in_loop = False
|
|
try:
|
|
loop.add_signal_handler(signal.SIGINT, _sigint_handler)
|
|
sigint_registered_in_loop = True
|
|
except (AttributeError, NotImplementedError):
|
|
# Windows (or any loop that doesn't support it) : fall back to sync
|
|
signal.signal(signal.SIGINT, lambda *_: _sigint_handler())
|
|
|
|
# Handle inputs (i.e. env variables injection)
|
|
resolved_inputs: dict[str, str] = {}
|
|
|
|
if len(inputs) > 0:
|
|
print(
|
|
ANSI.bold(
|
|
ANSI.blue(
|
|
"Some initial inputs are required by the agent. "
|
|
"Please provide a value or leave empty to load from env."
|
|
)
|
|
)
|
|
)
|
|
for input_item in inputs:
|
|
input_id = input_item["id"]
|
|
description = input_item["description"]
|
|
env_special_value = f"${{input:{input_id}}}"
|
|
|
|
# Check if the input is used by any server or as an apiKey
|
|
input_usages = set()
|
|
for server in servers:
|
|
# Check stdio's "env" and http/sse's "headers" mappings
|
|
env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})
|
|
for key, value in env_or_headers.items():
|
|
if env_special_value in value:
|
|
input_usages.add(key)
|
|
|
|
raw_api_key = config.get("apiKey")
|
|
if isinstance(raw_api_key, str) and env_special_value in raw_api_key:
|
|
input_usages.add("apiKey")
|
|
|
|
if not input_usages:
|
|
print(
|
|
ANSI.yellow(
|
|
f"Input '{input_id}' defined in config but not used by any server or as an API key."
|
|
" Skipping."
|
|
)
|
|
)
|
|
continue
|
|
|
|
# Prompt user for input
|
|
env_variable_key = input_id.replace("-", "_").upper()
|
|
print(
|
|
ANSI.blue(f" • {input_id}") + f": {description}. (default: load from {env_variable_key}).",
|
|
end=" ",
|
|
)
|
|
user_input = (await _async_prompt(exit_event=exit_event)).strip()
|
|
if exit_event.is_set():
|
|
return
|
|
|
|
# Fallback to environment variable when user left blank
|
|
final_value = user_input
|
|
if not final_value:
|
|
final_value = os.getenv(env_variable_key, "")
|
|
if final_value:
|
|
print(ANSI.green(f"Value successfully loaded from '{env_variable_key}'"))
|
|
else:
|
|
print(
|
|
ANSI.yellow(
|
|
f"No value found for '{env_variable_key}' in environment variables. Continuing."
|
|
)
|
|
)
|
|
resolved_inputs[input_id] = final_value
|
|
|
|
# Inject resolved value (can be empty) into stdio's env or http/sse's headers
|
|
for server in servers:
|
|
env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})
|
|
for key, value in env_or_headers.items():
|
|
if env_special_value in value:
|
|
env_or_headers[key] = env_or_headers[key].replace(env_special_value, final_value)
|
|
|
|
print()
|
|
|
|
raw_api_key = config.get("apiKey")
|
|
if isinstance(raw_api_key, str):
|
|
substituted_api_key = raw_api_key
|
|
for input_id, val in resolved_inputs.items():
|
|
substituted_api_key = substituted_api_key.replace(f"${{input:{input_id}}}", val)
|
|
config["apiKey"] = substituted_api_key
|
|
# Main agent loop
|
|
async with Agent(
|
|
provider=config.get("provider"), # type: ignore[arg-type]
|
|
model=config.get("model"),
|
|
base_url=config.get("endpointUrl"), # type: ignore[arg-type]
|
|
api_key=config.get("apiKey"),
|
|
servers=servers, # type: ignore[arg-type]
|
|
prompt=prompt,
|
|
) as agent:
|
|
await agent.load_tools()
|
|
print(ANSI.bold(ANSI.blue("Agent loaded with {} tools:".format(len(agent.available_tools)))))
|
|
for t in agent.available_tools:
|
|
print(ANSI.blue(f" • {t.function.name}"))
|
|
|
|
while True:
|
|
abort_event.clear()
|
|
|
|
# Check if we should exit
|
|
if exit_event.is_set():
|
|
return
|
|
|
|
try:
|
|
user_input = await _async_prompt(exit_event=exit_event)
|
|
first_sigint = True
|
|
except EOFError:
|
|
print(ANSI.red("\nEOF received, exiting."), flush=True)
|
|
break
|
|
except KeyboardInterrupt:
|
|
if not first_sigint and abort_event.is_set():
|
|
continue
|
|
else:
|
|
print(ANSI.red("\nKeyboard interrupt during input processing."), flush=True)
|
|
break
|
|
|
|
try:
|
|
async for chunk in agent.run(user_input, abort_event=abort_event):
|
|
if abort_event.is_set() and not first_sigint:
|
|
break
|
|
if exit_event.is_set():
|
|
return
|
|
|
|
if hasattr(chunk, "choices"):
|
|
delta = chunk.choices[0].delta
|
|
if delta.content:
|
|
print(delta.content, end="", flush=True)
|
|
if delta.tool_calls:
|
|
for call in delta.tool_calls:
|
|
if call.id:
|
|
print(f"<Tool {call.id}>", end="")
|
|
if call.function.name:
|
|
print(f"{call.function.name}", end=" ")
|
|
if call.function.arguments:
|
|
print(f"{call.function.arguments}", end="")
|
|
else:
|
|
print(
|
|
ANSI.green(f"\n\nTool[{chunk.name}] {chunk.tool_call_id}\n{chunk.content}\n"),
|
|
flush=True,
|
|
)
|
|
|
|
print()
|
|
|
|
except Exception as e:
|
|
tb_str = traceback.format_exc()
|
|
print(ANSI.red(f"\nError during agent run: {e}\n{tb_str}"), flush=True)
|
|
first_sigint = True # Allow graceful interrupt for the next command
|
|
|
|
except Exception as e:
|
|
tb_str = traceback.format_exc()
|
|
print(ANSI.red(f"\nAn unexpected error occurred: {e}\n{tb_str}"), flush=True)
|
|
raise e
|
|
|
|
finally:
|
|
if sigint_registered_in_loop:
|
|
try:
|
|
loop.remove_signal_handler(signal.SIGINT)
|
|
except (AttributeError, NotImplementedError):
|
|
pass
|
|
else:
|
|
signal.signal(signal.SIGINT, original_sigint_handler)
|
|
|
|
|
|
@run_cli.callback()
|
|
def run(
|
|
path: Optional[str] = typer.Argument(
|
|
None,
|
|
help=(
|
|
"Path to a local folder containing an agent.json file or a built-in agent "
|
|
"stored in the 'tiny-agents/tiny-agents' Hugging Face dataset "
|
|
"(https://huggingface.co/datasets/tiny-agents/tiny-agents)"
|
|
),
|
|
show_default=False,
|
|
),
|
|
):
|
|
try:
|
|
asyncio.run(run_agent(path))
|
|
except KeyboardInterrupt:
|
|
print(ANSI.red("\nApplication terminated by KeyboardInterrupt."), flush=True)
|
|
raise typer.Exit(code=130)
|
|
except Exception as e:
|
|
print(ANSI.red(f"\nAn unexpected error occurred: {e}"), flush=True)
|
|
raise e
|
|
|
|
|
|
if __name__ == "__main__":
|
|
app()
|