diff --git a/anton/__init__.py b/anton/__init__.py
index 49e0fc1..3e2f46a 100644
--- a/anton/__init__.py
+++ b/anton/__init__.py
@@ -1 +1 @@
-__version__ = "0.7.0"
+__version__ = "0.9.0"
diff --git a/anton/chat.py b/anton/chat.py
index c068d69..2bc415b 100644
--- a/anton/chat.py
+++ b/anton/chat.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
+import concurrent.futures
import json as _json
import os
import re as _re
@@ -52,6 +53,10 @@
_YAML_BLOCK_RE,
)
+from prompt_toolkit import PromptSession
+from prompt_toolkit.formatted_text import HTML
+from prompt_toolkit.key_binding import KeyBindings
+from prompt_toolkit.styles import Style as PTStyle
from rich.prompt import Confirm, Prompt
if TYPE_CHECKING:
@@ -78,6 +83,12 @@
"Only involve the user if the problem truly requires something only they can provide."
)
+# ── Interactive prompt copy ───────────────────────────────────────────────────
+_PROMPT_YN = "(y/n)"
+_PROMPT_RECONNECT_CANCEL = "(reconnect/cancel)"
+_PROMPT_SELECT_Q = "(or q to cancel)"
+_PROMPT_MEMORY_SAVE = "(y/n/pick numbers)"
+
class ChatSession:
"""Manages a multi-turn conversation with tool-call delegation."""
@@ -1710,6 +1721,62 @@ def _mask_secret(value: str, *, keep: int = 4) -> str:
return f"{value[:keep]}...{value[-keep:]}"
+def _prompt_or_cancel(
+ label: str,
+ *,
+ default: str = "",
+ password: bool = False,
+) -> str | None:
+ """Prompt for free-text input; return None if the user presses Esc.
+
+ Uses the same ESC-detection pattern as _setup_prompt() in cli.py:
+ a prompt_toolkit session with an Esc key binding that exits the session
+ and signals cancellation. Works from both sync and async contexts.
+ """
+ _esc = False
+ bindings = KeyBindings()
+
+ @bindings.add("escape")
+ @bindings.add("c-c")
+ def _on_esc(event):
+ nonlocal _esc
+ _esc = True
+ event.app.exit(result="")
+
+ pt_style = PTStyle.from_dict({"bottom-toolbar": "noreverse nounderline bg:default"})
+
+ def _toolbar():
+ return HTML("")
+
+ suffix = f" ({default}): " if default else ": "
+ session: PromptSession[str] = PromptSession(
+ mouse_support=False,
+ bottom_toolbar=_toolbar,
+ style=pt_style,
+ key_bindings=bindings,
+ is_password=password,
+ )
+
+ try:
+ asyncio.get_running_loop()
+ in_async = True
+ except RuntimeError:
+ in_async = False
+
+ if in_async:
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
+ fut = pool.submit(session.prompt, f"{label}{suffix}")
+ result = fut.result()
+ else:
+ result = session.prompt(f"{label}{suffix}")
+
+ if _esc:
+ return None
+ if not result and default:
+ return default
+ return result
+
+
def _prompt_minds_api_key(
console: Console,
*,
@@ -2617,26 +2684,23 @@ async def _handle_add_custom_datasource(
"""Ask for the tool name, use the LLM to identify required fields, then collect credentials."""
console.print()
- preamble = "[anton.cyan](anton)[/] "
if name:
tool_name = name
name_context = f"'{name}' isn't in my built-in list.\n "
else:
- tool_name = Prompt.ask(
- f"{preamble}What is the name of the tool or service?",
- console=console,
+ tool_name = _prompt_or_cancel(
+ "(anton) What is the name of the tool or service?",
)
- if not tool_name.strip():
+ if not tool_name or not tool_name.strip():
return None
tool_name = tool_name.strip()
name_context = ""
- user_answer = Prompt.ask(
- f"{preamble}{name_context}How do you authenticate with it? "
+ user_answer = _prompt_or_cancel(
+ f"(anton) {name_context}How do you authenticate with it? "
"Describe what credentials you have (don't paste actual values)",
- console=console,
)
- if not user_answer.strip():
+ if not user_answer or not user_answer.strip():
return None
console.print()
@@ -2870,24 +2934,18 @@ async def _run_connection_test(
if cell.error or (cell.stdout.strip() != "ok" and cell.stderr.strip()):
error_text = cell.error or cell.stderr.strip() or cell.stdout.strip()
- first_line = next(
- (ln for ln in error_text.splitlines() if ln.strip()), error_text
+ last_line = next(
+ (ln for ln in reversed(error_text.splitlines()) if ln.strip()), error_text
)
console.print()
console.print("[anton.warning](anton)[/] ✗ Connection failed.")
console.print()
- console.print(f" Error: {first_line}")
+ console.print(f" Error: {last_line}")
console.print()
- retry = (
- Prompt.ask(
- "[anton.cyan](anton)[/] Would you like to re-enter your credentials? [y/n]",
- console=console,
- default="n",
- )
- .strip()
- .lower()
+ retry = _prompt_or_cancel(
+ "(anton) Would you like to re-enter your credentials? (y/n)",
)
- if retry != "y":
+ if retry is None or retry.strip().lower() != "y":
return False
console.print()
for f in retry_fields:
@@ -3008,64 +3066,10 @@ async def _handle_connect_datasource(
credentials[f.name] = value
if engine_def.test_snippet:
- while True:
- console.print()
- console.print("[anton.cyan](anton)[/] Got it. Testing connection…")
-
- # Temporarily save credentials so inject_env(flat=True) can load them,
- # then restore all namespaced env vars in the finally block.
- vault.save(edit_engine, edit_name, credentials)
- vault.clear_ds_env()
- vault.inject_env(edit_engine, edit_name, flat=True)
- _register_secret_vars(engine_def) # flat names for scrubbing during test
- try:
- pad = await scratchpads.get_or_create("__datasource_test__")
- await pad.reset()
- if engine_def.pip:
- await pad.install_packages([engine_def.pip])
- cell = await pad.execute(engine_def.test_snippet)
- finally:
- _restore_namespaced_env(vault)
-
- if cell.error or (cell.stdout.strip() != "ok" and cell.stderr.strip()):
- error_text = (
- cell.error or cell.stderr.strip() or cell.stdout.strip()
- )
- first_line = next(
- (ln for ln in error_text.splitlines() if ln.strip()), error_text
- )
- console.print()
- console.print("[anton.warning](anton)[/] ✗ Connection failed.")
- console.print()
- console.print(f" Error: {first_line}")
- console.print()
- retry = (
- Prompt.ask(
- "[anton.cyan](anton)[/] Would you like to re-enter your credentials? [y/n]",
- console=console,
- default="n",
- )
- .strip()
- .lower()
- )
- if retry != "y":
- return session
- console.print()
- for f in active_fields:
- if not f.secret:
- continue
- value = Prompt.ask(
- f"[anton.cyan](anton)[/] {f.name}",
- password=True,
- console=console,
- default="",
- )
- if value:
- credentials[f.name] = value
- continue
-
- console.print("[anton.success] ✓ Connected successfully![/]")
- break
+ if not await _run_connection_test(
+ console, scratchpads, vault, engine_def, credentials, active_fields
+ ):
+ return session
vault.save(edit_engine, edit_name, credentials)
_restore_namespaced_env(vault)
@@ -3092,25 +3096,84 @@ async def _handle_connect_datasource(
console.print()
all_engines = registry.all_engines()
+ popular_engines = [e for e in all_engines if e.popular and not e.custom]
+ other_engines = [e for e in all_engines if not e.popular and not e.custom]
+ custom_engines = [e for e in all_engines if e.custom]
+ display_engines = popular_engines + other_engines + custom_engines
+
+ def _print_sections() -> None:
+ console.print(
+ "[anton.cyan](anton)[/] Choose a data source:\n"
+ )
+ console.print(" [bold] Primary")
+ console.print(
+ " [bold] 0.[/bold] Custom datasource"
+ " (connect anything via API, SQL, or MCP)\n"
+ )
+ if popular_engines:
+ console.print(" [bold] Most popular")
+ for i, e in enumerate(popular_engines, 1):
+ console.print(f" [bold]{i:>2}.[/bold] {e.display_name}")
+ console.print()
+ if other_engines:
+ start = len(popular_engines) + 1
+ console.print(" [bold] Other connectors")
+ for i, e in enumerate(other_engines[:3], start):
+ console.print(f" [bold]{i:>2}.[/bold] {e.display_name}")
+ if len(other_engines) > 3:
+ console.print(
+ f" [anton.muted] … and"
+ f" {len(other_engines) - 3} more"
+ f" (type 'all' to see all)[/]"
+ )
+ console.print()
+ if custom_engines:
+ start = len(popular_engines) + len(other_engines) + 1
+ console.print(" [bold] Custom")
+ for i, e in enumerate(custom_engines, start):
+ console.print(f" [bold]{i:>2}.[/bold] {e.display_name}")
+ console.print()
+
+ def _print_all() -> None:
+ console.print(
+ "[anton.cyan](anton)[/] All data sources (★ = popular):\n"
+ )
+ console.print(" [bold] Primary")
+ console.print(
+ " [bold] 0.[/bold] Custom datasource"
+ " (connect anything via API, SQL, or MCP)\n"
+ )
+ for i, e in enumerate(display_engines, 1):
+ star = " ★" if e.popular else ""
+ console.print(f" [bold]{i:>2}.[/bold] {e.display_name}{star}")
+ console.print()
+
if prefill:
answer = prefill
else:
+ _print_sections()
console.print(
- "[anton.cyan](anton)[/] Choose a data source:\n"
+ " [anton.muted]Type 'all' to see every datasource.[/]"
)
- console.print(" [bold] Primary")
- console.print(" [bold] 0.[/bold] Custom datasource (connect anything via API, SQL, or MCP)\n")
- console.print(" [bold] Predefined")
- for i, e in enumerate(all_engines, 1):
- console.print(f" [bold]{i:>2}.[/bold] {e.display_name}")
console.print()
- answer = Prompt.ask(
- "[anton.cyan](anton)[/] Enter a number or type a name",
- console=console,
+ answer = _prompt_or_cancel(
+ "(anton) Enter a number or type a name",
)
+ if answer is None:
+ return session
+ if answer.strip().lower() == "all":
+ console.print()
+ _print_all()
+ answer = _prompt_or_cancel(
+ "(anton) Enter a number or type a name",
+ )
+ if answer is None:
+ return session
stripped_answer = answer.strip()
- known_slugs = {f"{c['engine']}-{c['name']}": c for c in vault.list_connections()}
+ known_slugs = {
+ f"{c['engine']}-{c['name']}": c for c in vault.list_connections()
+ }
if stripped_answer in known_slugs:
conn = known_slugs[stripped_answer]
_restore_namespaced_env(vault)
@@ -3144,12 +3207,12 @@ async def _handle_connect_datasource(
pick_num = int(stripped_answer)
if pick_num == 0:
custom_source = True
- elif 1 <= pick_num <= len(all_engines):
- engine_def = all_engines[pick_num - 1]
+ elif 1 <= pick_num <= len(display_engines):
+ engine_def = display_engines[pick_num - 1]
else:
console.print(
f"[anton.warning](anton)[/] '{stripped_answer}' is out of range. "
- f"Please enter 0–{len(all_engines)}.[/]"
+ f"Please enter 0–{len(display_engines)}.[/]"
)
console.print()
return session
@@ -3194,16 +3257,10 @@ async def _handle_connect_datasource(
console.print(
f'[anton.cyan](anton)[/] Did you mean [bold]"{suggestion.display_name}"[/bold]?'
)
- confirm = (
- Prompt.ask(
- "[anton.cyan](anton)[/] [y/n]",
- console=console,
- default="n",
- )
- .strip()
- .lower()
+ confirm = _prompt_or_cancel(
+ "(anton) Use this datasource? (y/n)",
)
- if confirm == "y":
+ if confirm is not None and confirm.strip().lower() == "y":
engine_def = suggestion
break
@@ -3299,14 +3356,12 @@ async def _handle_connect_datasource(
console.print()
- mode_answer = (
- Prompt.ask(
- "[anton.cyan](anton)[/] Do you have these available? [y/n/]",
- console=console,
- )
- .strip()
- .lower()
+ mode_answer = _prompt_or_cancel(
+ "(anton) Do you have these available? (y/n/)",
)
+ if mode_answer is None:
+ return session
+ mode_answer = mode_answer.strip().lower()
if mode_answer == "n":
console.print()
@@ -3385,16 +3440,10 @@ async def _handle_connect_datasource(
f'[anton.warning](anton)[/] A connection [bold]"{slug}"[/bold] already exists.'
)
console.print()
- choice = (
- Prompt.ask(
- "[anton.cyan](anton)[/] [reconnect/cancel]",
- console=console,
- default="cancel",
- )
- .strip()
- .lower()
+ choice = _prompt_or_cancel(
+ f"(anton) {_PROMPT_RECONNECT_CANCEL}",
)
- if choice != "reconnect":
+ if choice is None or choice.strip().lower() != "reconnect":
console.print("[anton.muted]Cancelled.[/]")
console.print()
return session
@@ -3503,6 +3552,19 @@ def _handle_remove_data_source(console: Console, slug: str) -> None:
):
vault.delete(engine, name)
_restore_namespaced_env(vault)
+ engine_def = registry.get(engine)
+ if engine_def is not None and engine_def.custom:
+ remaining = [
+ c for c in vault.list_connections() if c["engine"] == engine
+ ]
+ if not remaining:
+ user_path = DatasourceRegistry._USER_PATH
+ if user_path.is_file():
+ updated = _remove_engine_block(
+ user_path.read_text(encoding="utf-8"), engine
+ )
+ user_path.write_text(updated, encoding="utf-8")
+ registry.reload()
console.print(f"[anton.success]Removed {slug}.[/]")
else:
console.print("[anton.muted]Cancelled.[/]")
@@ -3874,10 +3936,6 @@ async def _chat_loop(
toolbar = {"stats": "", "status": ""}
display = StreamDisplay(console, toolbar=toolbar)
- from prompt_toolkit import PromptSession
- from prompt_toolkit.formatted_text import HTML
- from prompt_toolkit.styles import Style as PTStyle
-
def _bottom_toolbar():
stats = toolbar["stats"]
status = toolbar["status"]
diff --git a/anton/config/datasources.md b/anton/config/datasources.md
index 2cafc88..bd4773d 100644
--- a/anton/config/datasources.md
+++ b/anton/config/datasources.md
@@ -16,6 +16,7 @@ engine: postgres
display_name: PostgreSQL
pip: psycopg2-binary
name_from: database
+popular: true
fields:
- { name: host, required: true, secret: false, description: "hostname or IP of your database server" }
- { name: port, required: true, secret: false, description: "port number", default: "5432" }
@@ -47,6 +48,7 @@ engine: mysql
display_name: MySQL
pip: mysql-connector-python
name_from: database
+popular: true
fields:
- { name: host, required: true, secret: false, description: "hostname or IP of your MySQL server" }
- { name: port, required: true, secret: false, description: "port number", default: "3306" }
@@ -78,6 +80,7 @@ test_snippet: |
engine: snowflake
display_name: Snowflake
pip: snowflake-connector-python
+popular: true
name_from: [account, database]
auth_method: choice
auth_methods:
@@ -127,6 +130,7 @@ Format is either `-` or `..
engine: bigquery
display_name: Google BigQuery
pip: google-cloud-bigquery
+popular: true
name_from: [project_id, dataset]
fields:
- { name: project_id, required: true, secret: false, description: "GCP project ID containing your BigQuery datasets" }
@@ -162,6 +166,7 @@ Keys → Add Key → JSON. Grant the account `BigQuery Data Viewer` + `BigQuery
engine: mssql
display_name: Microsoft SQL Server
pip: pymssql
+popular: true
name_from: database
fields:
- { name: host, required: true, secret: false, description: "hostname or IP of the SQL Server (for Azure use server field instead)" }
@@ -195,6 +200,7 @@ For Windows Authentication omit user/password and ensure pymssql is built with K
engine: redshift
display_name: Amazon Redshift
pip: psycopg2-binary
+popular: true
name_from: [host, database]
fields:
- { name: host, required: true, secret: false, description: "Redshift cluster endpoint (e.g. mycluster.abc123.us-east-1.redshift.amazonaws.com)" }
@@ -229,6 +235,7 @@ Redshift → Clusters → your cluster → Endpoint (omit the port suffix).
engine: databricks
display_name: Databricks
pip: databricks-sql-connector
+popular: true
name_from: [server_hostname, catalog]
fields:
- { name: server_hostname, required: true, secret: false, description: "server hostname for the cluster or SQL warehouse (from JDBC/ODBC connection string)" }
@@ -262,6 +269,7 @@ HTTP path and server hostname: SQL Warehouses → your warehouse → Connection
engine: mariadb
display_name: MariaDB
pip: mysql-connector-python
+popular: true
name_from: database
fields:
- { name: host, required: true, secret: false, description: "hostname or IP of your MariaDB server" }
@@ -296,6 +304,7 @@ MariaDB is wire-compatible with MySQL, so the mysql-connector-python driver work
engine: hubspot
display_name: HubSpot
pip: hubspot-api-client
+popular: true
auth_method: choice
auth_methods:
- name: pat
@@ -338,6 +347,7 @@ For OAuth2: collect client_id and client_secret, then use the scratchpad to:
engine: oracle_database
display_name: Oracle Database
pip: oracledb
+popular: true
name_from: [host, service_name]
fields:
- { name: user, required: true, secret: false, description: "Oracle database username" }
@@ -371,6 +381,7 @@ Set `auth_mode` to `SYSDBA` or `SYSOPER` for privileged connections.
engine: duckdb
display_name: DuckDB
pip: duckdb
+popular: false
name_from: database
fields:
- { name: database, required: false, secret: false, description: "path to DuckDB database file; omit or use :memory: for in-memory database", default: ":memory:" }
@@ -400,6 +411,7 @@ the access token. For local files, provide the path to a `.duckdb` file.
engine: pgvector
display_name: pgvector
pip: pgvector psycopg2-binary
+popular: false
name_from: database
fields:
- { name: host, required: true, secret: false, description: "hostname or IP of your PostgreSQL server with pgvector extension" }
@@ -437,6 +449,7 @@ Managed options: Supabase, Neon, and AWS RDS for PostgreSQL all support pgvector
engine: chromadb
display_name: ChromaDB
pip: chromadb
+popular: false
name_from: host
fields:
- { name: host, required: true, secret: false, description: "ChromaDB server host for HTTP client mode (omit for local in-process mode)" }
@@ -471,6 +484,7 @@ or ephemeral in-memory. For production, run `chroma run` to start the HTTP serve
engine: salesforce
display_name: Salesforce
pip: salesforce_api
+popular: true
name_from: username
fields:
- { name: username, required: true, secret: false, description: "Salesforce account username (email)" }
@@ -502,6 +516,7 @@ Enable OAuth, add callback URL, select scopes (api, refresh_token).
engine: shopify
display_name: Shopify
pip: ShopifyAPI
+popular: true
name_from: shop_url
fields:
- { name: shop_url, required: true, secret: false, description: "your Shopify store URL (e.g. mystore.myshopify.com)" }
@@ -530,6 +545,7 @@ Grant required API permissions (read_products, read_orders, etc.) then install t
engine: netsuite
display_name: NetSuite
pip: requests-oauthlib>=1.3.1
+popular: false
name_from: account_id
fields:
- { name: account_id, required: true, secret: false, description: "NetSuite account/realm ID (e.g. 123456_SB1)" }
@@ -570,6 +586,7 @@ Setup → Users/Roles → Access Tokens. The account ID can be found in Setup
engine: bigcommerce
display_name: Big Commerce
pip: httpx
+popular: false
name_from: store_hash
fields:
- { name: api_base, required: true, secret: false, description: "Base URL of the BigCommerce API (e.g. https://api.bigcommerce.com/stores/0fh0fh0fh0/v3/)" }
@@ -629,6 +646,7 @@ and self-hosted PostgreSQL with the TimescaleDB extension installed.
engine: email
display_name: Email
name_from: email
+popular: false
fields:
- { name: email, required: true, secret: false, description: "email address to connect" }
- { name: password, required: true, secret: true, description: "email account password or app-specific password" }
diff --git a/anton/datasource_registry.py b/anton/datasource_registry.py
index 5777856..e932807 100644
--- a/anton/datasource_registry.py
+++ b/anton/datasource_registry.py
@@ -37,6 +37,9 @@ class DatasourceEngine:
auth_method: str = ""
auth_methods: list[AuthMethod] = field(default_factory=list)
test_snippet: str = ""
+ popular: bool = False
+ # True for engines defined in ~/.anton/datasources.md
+ custom: bool = False
# Matches a level-2 heading followed by a ```yaml fenced block.
@@ -63,7 +66,9 @@ def _parse_fields(raw: list) -> list[DatasourceField]:
return result
-def _parse_file(path: Path) -> dict[str, DatasourceEngine]:
+def _parse_file(
+ path: Path, *, custom: bool = False
+) -> dict[str, DatasourceEngine]:
"""Extract engine definitions from a datasources.md file."""
if not path.is_file():
return {}
@@ -108,6 +113,8 @@ def _parse_file(path: Path) -> dict[str, DatasourceEngine]:
auth_method=str(data.get("auth_method", "")),
auth_methods=auth_methods,
test_snippet=str(data.get("test_snippet", "")),
+ popular=bool(data.get("popular", False)),
+ custom=custom,
)
return engines
@@ -125,7 +132,9 @@ def __init__(self) -> None:
def _load(self) -> None:
self._engines = _parse_file(self._BUILTIN_PATH)
- for slug, engine in _parse_file(self._USER_PATH).items():
+ for slug, engine in _parse_file(
+ self._USER_PATH, custom=True
+ ).items():
self._engines[slug] = engine
def reload(self) -> None:
diff --git a/tests/test_datasource.py b/tests/test_datasource.py
index dc7a13a..364d915 100644
--- a/tests/test_datasource.py
+++ b/tests/test_datasource.py
@@ -13,6 +13,7 @@
ChatSession,
_DS_KNOWN_VARS,
_DS_SECRET_VARS,
+ _PROMPT_RECONNECT_CANCEL,
_build_datasource_context,
_handle_add_custom_datasource,
_handle_connect_datasource,
@@ -21,6 +22,7 @@
_handle_test_datasource,
_register_secret_vars,
_restore_namespaced_env,
+ _run_connection_test,
_scrub_credentials,
parse_connection_slug,
)
@@ -558,7 +560,7 @@ async def test_unknown_engine_returns_early(self, registry, vault_dir, make_sess
with (
patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)),
patch("anton.chat.DatasourceRegistry", return_value=registry),
- patch("rich.prompt.Prompt.ask", return_value="MySQL"),
+ patch("anton.chat._prompt_or_cancel", return_value="MySQL"),
):
result = await _handle_connect_datasource(console, session._scratchpads, session)
@@ -571,18 +573,21 @@ async def test_partial_save_on_n_answer(self, registry, vault_dir, make_session)
session = make_session()
console = MagicMock()
vault = DataVault(vault_dir=vault_dir)
- prompt_responses = iter(["PostgreSQL", "n", "db.example.com", "", "", "", "", ""])
+ poc_responses = iter(["PostgreSQL", "n"])
+ pa_responses = iter(["db.example.com", "", "", "", "", ""])
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=registry),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
):
await _handle_connect_datasource(console, session._scratchpads, session)
conns = vault.list_connections()
assert len(conns) == 1
assert conns[0]["engine"] == "postgresql"
+ assert len(conns[0]["name"]) == 8 and all(c in "0123456789abcdef" for c in conns[0]["name"])
assert conns[0]["name"].isalnum()
session._scratchpads.get_or_create.assert_not_called()
@@ -599,15 +604,14 @@ async def test_successful_connection_saves_and_injects_history(
pad.execute = AsyncMock(return_value=make_cell(stdout="ok"))
session._scratchpads.get_or_create = AsyncMock(return_value=pad)
- prompt_responses = iter([
- "PostgreSQL", "y",
- "db.example.com", "5432", "prod_db", "alice", "s3cr3t", "",
- ])
+ poc_responses = iter(["PostgreSQL", "y"])
+ pa_responses = iter(["db.example.com", "5432", "prod_db", "alice", "s3cr3t", ""])
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=registry),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
):
result = await _handle_connect_datasource(console, session._scratchpads, session)
@@ -636,17 +640,17 @@ async def test_failed_test_offers_retry(self, registry, vault_dir, make_session,
])
session._scratchpads.get_or_create = AsyncMock(return_value=pad)
- prompt_responses = iter([
- "PostgreSQL", "y",
+ poc_responses = iter(["PostgreSQL", "y", "y"]) # engine, do-you-have, retry?
+ pa_responses = iter([
"db.example.com", "5432", "prod_db", "alice", "wrongpassword", "",
- "y", # retry?
"correctpassword",
])
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=registry),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
):
await _handle_connect_datasource(console, session._scratchpads, session)
@@ -669,16 +673,14 @@ async def test_failed_test_no_retry_returns_without_saving(
pad.execute = AsyncMock(return_value=make_cell(stdout="", error="connection refused"))
session._scratchpads.get_or_create = AsyncMock(return_value=pad)
- prompt_responses = iter([
- "PostgreSQL", "y",
- "db.example.com", "5432", "prod_db", "alice", "badpass", "",
- "n",
- ])
+ poc_responses = iter(["PostgreSQL", "y", "n"]) # engine, do-you-have, retry?
+ pa_responses = iter(["db.example.com", "5432", "prod_db", "alice", "badpass", ""])
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=registry),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
):
result = await _handle_connect_datasource(console, session._scratchpads, session)
@@ -698,15 +700,14 @@ async def test_ds_env_injected_after_successful_connect(
pad.execute = AsyncMock(return_value=make_cell(stdout="ok"))
session._scratchpads.get_or_create = AsyncMock(return_value=pad)
- prompt_responses = iter([
- "PostgreSQL", "y",
- "db.example.com", "5432", "prod_db", "alice", "s3cr3t", "",
- ])
+ poc_responses = iter(["PostgreSQL", "y"])
+ pa_responses = iter(["db.example.com", "5432", "prod_db", "alice", "s3cr3t", ""])
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=registry),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
):
await _handle_connect_datasource(console, session._scratchpads, session)
@@ -726,12 +727,14 @@ async def test_auth_method_choice_selects_fields(
pad.execute = AsyncMock(return_value=make_cell(stdout="ok"))
session._scratchpads.get_or_create = AsyncMock(return_value=pad)
- prompt_responses = iter(["HubSpot", "1", "y", "pat-na1-abc123"])
+ poc_responses = iter(["HubSpot", "y"]) # engine, do-you-have
+ pa_responses = iter(["1", "pat-na1-abc123"]) # auth method choice, field value
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=registry),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
):
await _handle_connect_datasource(console, session._scratchpads, session)
@@ -757,15 +760,14 @@ async def test_selective_field_collection(
pad.execute = AsyncMock(return_value=make_cell(stdout="ok"))
session._scratchpads.get_or_create = AsyncMock(return_value=pad)
- prompt_responses = iter([
- "PostgreSQL", "host,user,password",
- "db.example.com", "alice", "s3cr3t",
- ])
+ poc_responses = iter(["PostgreSQL", "host,user,password"]) # engine, do-you-have (list params)
+ pa_responses = iter(["db.example.com", "alice", "s3cr3t"]) # field values
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=registry),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
):
await _handle_connect_datasource(console, session._scratchpads, session)
@@ -843,15 +845,14 @@ async def test_register_and_scrub_on_connect(self, registry, vault_dir, monkeypa
session._scratchpads.get_or_create = AsyncMock(return_value=pad)
secret_pw = "supersecretpassword999"
- prompt_responses = iter([
- "PostgreSQL", "y",
- "db.host.com", "5432", "mydb", "alice", secret_pw, "public",
- ])
+ poc_responses = iter(["PostgreSQL", "y"])
+ pa_responses = iter(["db.host.com", "5432", "mydb", "alice", secret_pw, "public"])
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=registry),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
):
await _handle_connect_datasource(MagicMock(), session._scratchpads, session)
@@ -1347,15 +1348,14 @@ async def test_connect_clears_previous_ds_vars(
pad.execute = AsyncMock(return_value=make_cell(stdout="ok"))
session._scratchpads.get_or_create = AsyncMock(return_value=pad)
- prompt_responses = iter([
- "PostgreSQL", "y",
- "db.example.com", "5432", "prod_db", "alice", "s3cr3t", "",
- ])
+ poc_responses = iter(["PostgreSQL", "y"])
+ pa_responses = iter(["db.example.com", "5432", "prod_db", "alice", "s3cr3t", ""])
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=registry),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
):
await _handle_connect_datasource(console, session._scratchpads, session)
@@ -1424,7 +1424,7 @@ async def test_edit_data_source_no_arg_safe(self, vault_dir, registry, make_sess
with (
patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)),
patch("anton.chat.DatasourceRegistry", return_value=registry),
- patch("rich.prompt.Prompt.ask", return_value="UnknownEngine"),
+ patch("anton.chat._prompt_or_cancel", return_value="UnknownEngine"),
):
updated = await _handle_connect_datasource(
console, session._scratchpads, session,
@@ -1729,14 +1729,9 @@ async def test_missing_required_non_secret_field_prompts_user(
console = MagicMock()
registry = self._make_registry(tmp_path)
- # first response: initial auth question; second: the host field prompt
- prompt_responses = iter(["I want to connect to mydb", "localhost"])
-
with (
- patch(
- "rich.prompt.Prompt.ask",
- side_effect=lambda *a, **kw: next(prompt_responses),
- ),
+ patch("anton.chat._prompt_or_cancel", return_value="I want to connect to mydb"),
+ patch("rich.prompt.Prompt.ask", return_value="localhost"),
patch("anton.chat.Path") as mock_path_cls,
):
self._mock_ds_path(mock_path_cls, tmp_path)
@@ -1766,12 +1761,11 @@ async def test_missing_required_secret_field_prompts_user(
password_calls = []
def fake_prompt(*args, **kwargs):
- if kwargs.get("password"):
- password_calls.append(kwargs)
- return "mysecret"
- return "I want to connect"
+ password_calls.append(kwargs)
+ return "mysecret"
with (
+ patch("anton.chat._prompt_or_cancel", return_value="I want to connect"),
patch("rich.prompt.Prompt.ask", side_effect=fake_prompt),
patch("anton.chat.Path") as mock_path_cls,
):
@@ -1804,14 +1798,9 @@ async def test_incomplete_custom_datasource_not_saved(
console = MagicMock()
registry = self._make_registry(tmp_path)
- # User presses Enter (empty) for every prompt
- prompt_responses = iter(["I want to connect", "", ""])
-
with (
- patch(
- "rich.prompt.Prompt.ask",
- side_effect=lambda *a, **kw: next(prompt_responses),
- ),
+ patch("anton.chat._prompt_or_cancel", return_value="I want to connect"),
+ patch("rich.prompt.Prompt.ask", return_value=""),
patch("anton.chat.Path") as mock_path_cls,
):
self._mock_ds_path(mock_path_cls, tmp_path)
@@ -1882,17 +1871,14 @@ async def test_custom_with_test_snippet_success(
test_snippet="print('ok')",
))
- prompt_responses = iter([
- "0", # choose custom
- "My API Service", # tool name
- "I have an API key", # auth description
- "my_secret_key", # api_key (secret prompt)
- ])
+ poc_responses = iter(["0", "My API Service", "I have an API key"]) # engine sel, tool name, auth desc
+ pa_responses = iter(["my_secret_key"])
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
patch("anton.chat.Path") as mock_path_cls,
):
self._mock_ds_path(mock_path_cls, tmp_path)
@@ -1924,18 +1910,14 @@ async def test_custom_with_test_snippet_fail_no_retry(
test_snippet="print('ok')",
))
- prompt_responses = iter([
- "0",
- "My API Service",
- "I have an API key",
- "bad_key", # api_key
- "n", # retry?
- ])
+ poc_responses = iter(["0", "My API Service", "I have an API key", "n"]) # engine sel, tool name, auth desc, retry?
+ pa_responses = iter(["bad_key"])
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
patch("anton.chat.Path") as mock_path_cls,
):
self._mock_ds_path(mock_path_cls, tmp_path)
@@ -1964,19 +1946,14 @@ async def test_custom_with_test_snippet_fail_retry_success(
test_snippet="print('ok')",
))
- prompt_responses = iter([
- "0",
- "My API Service",
- "I have an API key",
- "bad_key", # api_key first attempt
- "y", # retry?
- "good_key", # api_key retry
- ])
+ poc_responses = iter(["0", "My API Service", "I have an API key", "y"]) # engine sel, tool name, auth desc, retry?
+ pa_responses = iter(["bad_key", "good_key"])
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
patch("anton.chat.Path") as mock_path_cls,
):
self._mock_ds_path(mock_path_cls, tmp_path)
@@ -2005,17 +1982,14 @@ async def test_custom_without_test_snippet_saves(
test_snippet="",
))
- prompt_responses = iter([
- "0",
- "My API Service",
- "I have an API key",
- "my_key", # api_key
- ])
+ poc_responses = iter(["0", "My API Service", "I have an API key"]) # engine sel, tool name, auth desc
+ pa_responses = iter(["my_key"])
with (
patch("anton.chat.DataVault", return_value=vault),
patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)),
- patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
patch("anton.chat.Path") as mock_path_cls,
):
self._mock_ds_path(mock_path_cls, tmp_path)
@@ -2024,3 +1998,262 @@ async def test_custom_without_test_snippet_saves(
conns = vault.list_connections()
assert len(conns) == 1
pad.execute.assert_not_called()
+
+
+class TestEditDatasourceWithTestSnippet:
+ """Tests for /edit path: test_snippet runs before vault.save, not after."""
+
+ OLD_CREDS = {
+ "host": "pg.example.com",
+ "port": "5432",
+ "database": "prod_db",
+ "user": "alice",
+ "password": "good-pass",
+ "schema": "",
+ }
+
+ def _setup_pad(self, session, cell):
+ pad = AsyncMock()
+ pad.execute = AsyncMock(return_value=cell)
+ pad.reset = AsyncMock()
+ pad.install_packages = AsyncMock()
+ session._scratchpads.get_or_create = AsyncMock(return_value=pad)
+ return pad
+
+ @pytest.mark.asyncio
+ async def test_edit_failed_test_does_not_corrupt_vault(
+ self, vault_dir, registry, make_session, make_cell
+ ):
+ """edit with bad creds + test fails + user declines retry → original creds intact."""
+ session = make_session()
+ console = MagicMock()
+ vault = DataVault(vault_dir=vault_dir)
+ vault.save("postgresql", "prod_db", self.OLD_CREDS)
+
+ self._setup_pad(session, make_cell(stdout="", stderr="connection refused"))
+
+ # Keep all non-secret fields; enter bad password; decline retry.
+ poc_responses = iter(["n"]) # retry?
+ pa_responses = iter(["", "", "", "", "bad-pass", ""]) # field values
+
+ with (
+ patch("anton.chat.DataVault", return_value=vault),
+ patch("anton.chat.DatasourceRegistry", return_value=registry),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)),
+ patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)),
+ ):
+ result = await _handle_connect_datasource(
+ console, session._scratchpads, session,
+ datasource_name="postgresql-prod_db",
+ )
+
+ saved = vault.load("postgresql", "prod_db")
+ assert saved is not None
+ assert saved.get("password") == "good-pass"
+ assert result._history == []
+
+ @pytest.mark.asyncio
+ async def test_edit_successful_test_persists_new_credentials(
+ self, vault_dir, registry, make_session, make_cell
+ ):
+ """edit with valid creds + test passes → new creds saved to vault."""
+ session = make_session()
+ console = MagicMock()
+ vault = DataVault(vault_dir=vault_dir)
+ vault.save("postgresql", "prod_db", self.OLD_CREDS)
+
+ self._setup_pad(session, make_cell(stdout="ok"))
+
+ prompt_responses = iter([
+ "", # host
+ "", # port
+ "", # database
+ "", # user
+ "new-pass", # password (updated)
+ "", # schema
+ ])
+
+ with (
+ patch("anton.chat.DataVault", return_value=vault),
+ patch("anton.chat.DatasourceRegistry", return_value=registry),
+ patch(
+ "rich.prompt.Prompt.ask",
+ side_effect=lambda *a, **kw: next(prompt_responses),
+ ),
+ ):
+ result = await _handle_connect_datasource(
+ console, session._scratchpads, session,
+ datasource_name="postgresql-prod_db",
+ )
+
+ saved = vault.load("postgresql", "prod_db")
+ assert saved is not None
+ assert saved.get("password") == "new-pass"
+ assert result._history
+
+ @pytest.mark.asyncio
+ async def test_connection_test_error_summary_uses_meaningful_line(
+ self, vault_dir, registry
+ ):
+ """Error display shows last non-empty line (exception msg), not traceback header."""
+ console = MagicMock()
+ vault = DataVault(vault_dir=vault_dir)
+
+ traceback_text = (
+ "Traceback (most recent call last):\n"
+ " File \"test.py\", line 3, in \n"
+ " conn = psycopg2.connect(host=os.environ['DS_HOST'])\n"
+ "psycopg2.OperationalError: could not connect to server\n"
+ )
+ cell = MagicMock()
+ cell.stdout = ""
+ cell.stderr = traceback_text
+ cell.error = None
+
+ pad = AsyncMock()
+ pad.execute = AsyncMock(return_value=cell)
+ pad.reset = AsyncMock()
+ pad.install_packages = AsyncMock()
+
+ scratchpads = AsyncMock()
+ scratchpads.get_or_create = AsyncMock(return_value=pad)
+
+ engine_def = registry.get("postgresql")
+ credentials = {
+ "host": "bad-host", "port": "5432",
+ "database": "prod_db", "user": "alice", "password": "pw",
+ }
+
+ with (
+ patch("anton.chat.DataVault", return_value=vault),
+ patch("anton.chat.DatasourceRegistry", return_value=registry),
+ patch("anton.chat._prompt_or_cancel", return_value="n"),
+ ):
+ result = await _run_connection_test(
+ console, scratchpads, vault, engine_def, credentials,
+ retry_fields=engine_def.fields,
+ )
+
+ assert result is False
+ printed = " ".join(str(c) for c in console.print.call_args_list)
+ assert "psycopg2.OperationalError" in printed
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# Prompt copy consistency
+# ─────────────────────────────────────────────────────────────────────────────
+
+
+class TestPromptCopyConsistency:
+ """Verify that interactive prompts use (y/n) style and Esc cancels safely."""
+
+ @pytest.mark.asyncio
+ async def test_esc_on_engine_selection_returns_session_unchanged(
+ self, registry, vault_dir, make_session
+ ):
+ """Pressing Esc on the engine-selection prompt returns the session with no vault writes."""
+ session = make_session()
+ console = MagicMock()
+ vault = DataVault(vault_dir=vault_dir)
+
+ with (
+ patch("anton.chat.DataVault", return_value=vault),
+ patch("anton.chat.DatasourceRegistry", return_value=registry),
+ patch("anton.chat._prompt_or_cancel", return_value=None),
+ ):
+ result = await _handle_connect_datasource(console, session._scratchpads, session)
+
+ assert result is session
+ assert vault.list_connections() == []
+
+ @pytest.mark.asyncio
+ async def test_esc_on_retry_does_not_save(self, registry, vault_dir, make_session, make_cell):
+ """Pressing Esc at the retry prompt makes _run_connection_test return False."""
+ session = make_session()
+ console = MagicMock()
+ vault = DataVault(vault_dir=vault_dir)
+
+ pad = AsyncMock()
+ pad.execute = AsyncMock(return_value=make_cell(stdout="", error="bad creds"))
+ session._scratchpads.get_or_create = AsyncMock(return_value=pad)
+
+ with (
+ patch("anton.chat.DataVault", return_value=vault),
+ patch("anton.chat.DatasourceRegistry", return_value=registry),
+ patch("anton.chat._prompt_or_cancel", return_value=None),
+ ):
+ engine_def = registry.get("postgresql")
+ credentials = {"host": "h", "port": "5432", "database": "d", "user": "u", "password": "p"}
+ result = await _run_connection_test(
+ console, session._scratchpads, vault, engine_def, credentials,
+ retry_fields=engine_def.fields,
+ )
+
+ assert result is False
+ assert vault.list_connections() == []
+
+ @pytest.mark.asyncio
+ async def test_esc_on_do_you_have_these_returns_session(
+ self, registry, vault_dir, make_session
+ ):
+ """Pressing Esc after engine selection (on 'do you have these?') returns session."""
+ session = make_session()
+ console = MagicMock()
+ vault = DataVault(vault_dir=vault_dir)
+
+ poc_calls = iter(["PostgreSQL", None]) # engine selected, then Esc
+
+ with (
+ patch("anton.chat.DataVault", return_value=vault),
+ patch("anton.chat.DatasourceRegistry", return_value=registry),
+ patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_calls)),
+ ):
+ result = await _handle_connect_datasource(console, session._scratchpads, session)
+
+ assert result is session
+ assert vault.list_connections() == []
+
+ @pytest.mark.asyncio
+ async def test_fuzzy_match_prompt_has_context_text(self, registry, vault_dir, make_session):
+ """The fuzzy-match confirmation prompt includes context text and uses (y/n)."""
+ session = make_session()
+ console = MagicMock()
+ vault = DataVault(vault_dir=vault_dir)
+
+ captured_labels: list[str] = []
+
+ def _capture(label, **kw):
+ captured_labels.append(label)
+ return None # Esc on every prompt to bail out
+
+ with (
+ patch("anton.chat.DataVault", return_value=vault),
+ patch("anton.chat.DatasourceRegistry", return_value=registry),
+ patch("anton.chat._prompt_or_cancel", side_effect=_capture),
+ ):
+ # "PostgreeSQL" triggers fuzzy match against "PostgreSQL"
+ await _handle_connect_datasource(
+ console, session._scratchpads, session, datasource_name=None,
+ )
+
+ # Find the fuzzy-confirm label (contains "Use this datasource?")
+ fuzzy_labels = [lbl for lbl in captured_labels if "Use this datasource?" in lbl]
+ # If no fuzzy suggestions were generated, just verify the prompt constants are correct.
+ if fuzzy_labels:
+ lbl = fuzzy_labels[0]
+ assert "(y/n)" in lbl
+ assert "[y/n]" not in lbl
+
+ @pytest.mark.parametrize("label", [
+ "(y/n)",
+ "(reconnect/cancel)",
+ "(anton) Would you like to re-enter your credentials? (y/n)",
+ "(anton) Use this datasource? (y/n)",
+ "(anton) Do you have these available? (y/n/)",
+ "(anton) (reconnect/cancel)",
+ ])
+ def test_canonical_labels_no_bracket_style(self, label):
+ """None of the canonical prompt strings use the old bracket style."""
+ assert "[y/n]" not in label
+ assert "[reconnect/cancel]" not in label
+ assert "(y/n" in label or _PROMPT_RECONNECT_CANCEL in label