From e5c19d30af588ddfe794acfcce35ac648a531d8b Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 27 Mar 2026 16:26:37 +0100 Subject: [PATCH 1/8] Better tests execution of the snippet with predefined and custom datasources --- anton/chat.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index cc09911..67f4116 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2870,13 +2870,13 @@ async def _run_connection_test( if cell.error or (cell.stdout.strip() != "ok" and cell.stderr.strip()): error_text = cell.error or cell.stderr.strip() or cell.stdout.strip() - first_line = next( - (ln for ln in error_text.splitlines() if ln.strip()), error_text + last_line = next( + (ln for ln in reversed(error_text.splitlines()) if ln.strip()), error_text ) console.print() console.print("[anton.warning](anton)[/] ✗ Connection failed.") console.print() - console.print(f" Error: {first_line}") + console.print(f" Error: {last_line}") console.print() retry = ( Prompt.ask( From c5d5acb3557341ebc9f35761204d52ee94fc9221 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 27 Mar 2026 16:27:23 +0100 Subject: [PATCH 2/8] Remove unused code --- anton/chat.py | 62 ++++----------------------------------------------- 1 file changed, 4 insertions(+), 58 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index 67f4116..a2e0c05 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -3008,64 +3008,10 @@ async def _handle_connect_datasource( credentials[f.name] = value if engine_def.test_snippet: - while True: - console.print() - console.print("[anton.cyan](anton)[/] Got it. Testing connection…") - - # Temporarily save credentials so inject_env(flat=True) can load them, - # then restore all namespaced env vars in the finally block. - vault.save(edit_engine, edit_name, credentials) - vault.clear_ds_env() - vault.inject_env(edit_engine, edit_name, flat=True) - _register_secret_vars(engine_def) # flat names for scrubbing during test - try: - pad = await scratchpads.get_or_create("__datasource_test__") - await pad.reset() - if engine_def.pip: - await pad.install_packages([engine_def.pip]) - cell = await pad.execute(engine_def.test_snippet) - finally: - _restore_namespaced_env(vault) - - if cell.error or (cell.stdout.strip() != "ok" and cell.stderr.strip()): - error_text = ( - cell.error or cell.stderr.strip() or cell.stdout.strip() - ) - first_line = next( - (ln for ln in error_text.splitlines() if ln.strip()), error_text - ) - console.print() - console.print("[anton.warning](anton)[/] ✗ Connection failed.") - console.print() - console.print(f" Error: {first_line}") - console.print() - retry = ( - Prompt.ask( - "[anton.cyan](anton)[/] Would you like to re-enter your credentials? [y/n]", - console=console, - default="n", - ) - .strip() - .lower() - ) - if retry != "y": - return session - console.print() - for f in active_fields: - if not f.secret: - continue - value = Prompt.ask( - f"[anton.cyan](anton)[/] {f.name}", - password=True, - console=console, - default="", - ) - if value: - credentials[f.name] = value - continue - - console.print("[anton.success] ✓ Connected successfully![/]") - break + if not await _run_connection_test( + console, scratchpads, vault, engine_def, credentials, active_fields + ): + return session vault.save(edit_engine, edit_name, credentials) _restore_namespaced_env(vault) From ce9b96c19c33346bc7b41fcdc0d14fa5ac5336e4 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 27 Mar 2026 16:27:50 +0100 Subject: [PATCH 3/8] Include the method in the import --- tests/test_datasource.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_datasource.py b/tests/test_datasource.py index 3989cdf..052f42a 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -21,6 +21,7 @@ _handle_test_datasource, _register_secret_vars, _restore_namespaced_env, + _run_connection_test, _scrub_credentials, parse_connection_slug, ) From 2b9952be58d5b3912b9fb8c6c40ad6927fd51f0f Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 27 Mar 2026 16:37:24 +0100 Subject: [PATCH 4/8] Add some tests --- tests/test_datasource.py | 149 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) diff --git a/tests/test_datasource.py b/tests/test_datasource.py index 052f42a..021a0a6 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -2025,3 +2025,152 @@ async def test_custom_without_test_snippet_saves( conns = vault.list_connections() assert len(conns) == 1 pad.execute.assert_not_called() + + +class TestEditDatasourceWithTestSnippet: + """Tests for /edit path: test_snippet runs before vault.save, not after.""" + + OLD_CREDS = { + "host": "pg.example.com", + "port": "5432", + "database": "prod_db", + "user": "alice", + "password": "good-pass", + "schema": "", + } + + def _setup_pad(self, session, cell): + pad = AsyncMock() + pad.execute = AsyncMock(return_value=cell) + pad.reset = AsyncMock() + pad.install_packages = AsyncMock() + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + return pad + + @pytest.mark.asyncio + async def test_edit_failed_test_does_not_corrupt_vault( + self, vault_dir, registry, make_session, make_cell + ): + """edit with bad creds + test fails + user declines retry → original creds intact.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", self.OLD_CREDS) + + self._setup_pad(session, make_cell(stdout="", stderr="connection refused")) + + # Keep all non-secret fields; enter bad password; decline retry. + prompt_responses = iter([ + "", # host (keep existing) + "", # port + "", # database + "", # user + "bad-pass", # password (new, bad) + "", # schema + "n", # retry? + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch( + "rich.prompt.Prompt.ask", + side_effect=lambda *a, **kw: next(prompt_responses), + ), + ): + result = await _handle_connect_datasource( + console, session._scratchpads, session, + datasource_name="postgresql-prod_db", + ) + + saved = vault.load("postgresql", "prod_db") + assert saved is not None + assert saved.get("password") == "good-pass" + assert result._history == [] + + @pytest.mark.asyncio + async def test_edit_successful_test_persists_new_credentials( + self, vault_dir, registry, make_session, make_cell + ): + """edit with valid creds + test passes → new creds saved to vault.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", self.OLD_CREDS) + + self._setup_pad(session, make_cell(stdout="ok")) + + prompt_responses = iter([ + "", # host + "", # port + "", # database + "", # user + "new-pass", # password (updated) + "", # schema + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch( + "rich.prompt.Prompt.ask", + side_effect=lambda *a, **kw: next(prompt_responses), + ), + ): + result = await _handle_connect_datasource( + console, session._scratchpads, session, + datasource_name="postgresql-prod_db", + ) + + saved = vault.load("postgresql", "prod_db") + assert saved is not None + assert saved.get("password") == "new-pass" + assert result._history + + @pytest.mark.asyncio + async def test_connection_test_error_summary_uses_meaningful_line( + self, vault_dir, registry + ): + """Error display shows last non-empty line (exception msg), not traceback header.""" + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + traceback_text = ( + "Traceback (most recent call last):\n" + " File \"test.py\", line 3, in \n" + " conn = psycopg2.connect(host=os.environ['DS_HOST'])\n" + "psycopg2.OperationalError: could not connect to server\n" + ) + cell = MagicMock() + cell.stdout = "" + cell.stderr = traceback_text + cell.error = None + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=cell) + pad.reset = AsyncMock() + pad.install_packages = AsyncMock() + + scratchpads = AsyncMock() + scratchpads.get_or_create = AsyncMock(return_value=pad) + + engine_def = registry.get("postgresql") + credentials = { + "host": "bad-host", "port": "5432", + "database": "prod_db", "user": "alice", "password": "pw", + } + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", return_value="n"), + ): + result = await _run_connection_test( + console, scratchpads, vault, engine_def, credentials, + retry_fields=engine_def.fields, + ) + + assert result is False + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "psycopg2.OperationalError" in printed + assert "Traceback (most recent call last)" not in printed From 27bd57127f185e082c972fa13f1cef4f39584c04 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 27 Mar 2026 16:54:07 +0100 Subject: [PATCH 5/8] Test updates --- tests/test_datasource.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_datasource.py b/tests/test_datasource.py index 021a0a6..995d179 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -584,7 +584,7 @@ async def test_partial_save_on_n_answer(self, registry, vault_dir, make_session) conns = vault.list_connections() assert len(conns) == 1 assert conns[0]["engine"] == "postgresql" - assert conns[0]["name"].isdigit() + assert len(conns[0]["name"]) == 8 and all(c in "0123456789abcdef" for c in conns[0]["name"]) session._scratchpads.get_or_create.assert_not_called() @pytest.mark.asyncio From fe44a0bc5b50ef602f49f45fbce5c3a57aadb5aa Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 27 Mar 2026 19:15:42 +0100 Subject: [PATCH 6/8] Fix the connection issues with Esc and increase the version to 0.9.0 --- anton/__init__.py | 2 +- anton/chat.py | 141 ++++++++++++------- tests/test_datasource.py | 292 +++++++++++++++++++++++++-------------- 3 files changed, 279 insertions(+), 156 deletions(-) diff --git a/anton/__init__.py b/anton/__init__.py index 49e0fc1..3e2f46a 100644 --- a/anton/__init__.py +++ b/anton/__init__.py @@ -1 +1 @@ -__version__ = "0.7.0" +__version__ = "0.9.0" diff --git a/anton/chat.py b/anton/chat.py index fb76331..36ca6cf 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import concurrent.futures import json as _json import os import re as _re @@ -52,6 +53,10 @@ _YAML_BLOCK_RE, ) +from prompt_toolkit import PromptSession +from prompt_toolkit.formatted_text import HTML +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.styles import Style as PTStyle from rich.prompt import Confirm, Prompt if TYPE_CHECKING: @@ -78,6 +83,12 @@ "Only involve the user if the problem truly requires something only they can provide." ) +# ── Interactive prompt copy ─────────────────────────────────────────────────── +_PROMPT_YN = "(y/n)" +_PROMPT_RECONNECT_CANCEL = "(reconnect/cancel)" +_PROMPT_SELECT_Q = "(or q to cancel)" +_PROMPT_MEMORY_SAVE = "(y/n/pick numbers)" + class ChatSession: """Manages a multi-turn conversation with tool-call delegation.""" @@ -1710,6 +1721,62 @@ def _mask_secret(value: str, *, keep: int = 4) -> str: return f"{value[:keep]}...{value[-keep:]}" +def _prompt_or_cancel( + label: str, + *, + default: str = "", + password: bool = False, +) -> str | None: + """Prompt for free-text input; return None if the user presses Esc. + + Uses the same ESC-detection pattern as _setup_prompt() in cli.py: + a prompt_toolkit session with an Esc key binding that exits the session + and signals cancellation. Works from both sync and async contexts. + """ + _esc = False + bindings = KeyBindings() + + @bindings.add("escape") + @bindings.add("c-c") + def _on_esc(event): + nonlocal _esc + _esc = True + event.app.exit(result="") + + pt_style = PTStyle.from_dict({"bottom-toolbar": "noreverse nounderline bg:default"}) + + def _toolbar(): + return HTML("") + + suffix = f" ({default}): " if default else ": " + session: PromptSession[str] = PromptSession( + mouse_support=False, + bottom_toolbar=_toolbar, + style=pt_style, + key_bindings=bindings, + is_password=password, + ) + + try: + asyncio.get_running_loop() + in_async = True + except RuntimeError: + in_async = False + + if in_async: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + fut = pool.submit(session.prompt, f"{label}{suffix}") + result = fut.result() + else: + result = session.prompt(f"{label}{suffix}") + + if _esc: + return None + if not result and default: + return default + return result + + def _prompt_minds_api_key( console: Console, *, @@ -2617,26 +2684,23 @@ async def _handle_add_custom_datasource( """Ask for the tool name, use the LLM to identify required fields, then collect credentials.""" console.print() - preamble = "[anton.cyan](anton)[/] " if name: tool_name = name name_context = f"'{name}' isn't in my built-in list.\n " else: - tool_name = Prompt.ask( - f"{preamble}What is the name of the tool or service?", - console=console, + tool_name = _prompt_or_cancel( + "(anton) What is the name of the tool or service?", ) - if not tool_name.strip(): + if not tool_name or not tool_name.strip(): return None tool_name = tool_name.strip() name_context = "" - user_answer = Prompt.ask( - f"{preamble}{name_context}How do you authenticate with it? " + user_answer = _prompt_or_cancel( + f"(anton) {name_context}How do you authenticate with it? " "Describe what credentials you have (don't paste actual values)", - console=console, ) - if not user_answer.strip(): + if not user_answer or not user_answer.strip(): return None console.print() @@ -2878,16 +2942,10 @@ async def _run_connection_test( console.print() console.print(f" Error: {last_line}") console.print() - retry = ( - Prompt.ask( - "[anton.cyan](anton)[/] Would you like to re-enter your credentials? [y/n]", - console=console, - default="n", - ) - .strip() - .lower() + retry = _prompt_or_cancel( + "(anton) Would you like to re-enter your credentials? (y/n)", ) - if retry != "y": + if retry is None or retry.strip().lower() != "y": return False console.print() for f in retry_fields: @@ -3050,10 +3108,11 @@ async def _handle_connect_datasource( for i, e in enumerate(all_engines, 1): console.print(f" [bold]{i:>2}.[/bold] {e.display_name}") console.print() - answer = Prompt.ask( - "[anton.cyan](anton)[/] Enter a number or type a name", - console=console, + answer = _prompt_or_cancel( + "(anton) Enter a number or type a name", ) + if answer is None: + return session stripped_answer = answer.strip() known_slugs = {f"{c['engine']}-{c['name']}": c for c in vault.list_connections()} @@ -3140,16 +3199,10 @@ async def _handle_connect_datasource( console.print( f'[anton.cyan](anton)[/] Did you mean [bold]"{suggestion.display_name}"[/bold]?' ) - confirm = ( - Prompt.ask( - "[anton.cyan](anton)[/] [y/n]", - console=console, - default="n", - ) - .strip() - .lower() + confirm = _prompt_or_cancel( + "(anton) Use this datasource? (y/n)", ) - if confirm == "y": + if confirm is not None and confirm.strip().lower() == "y": engine_def = suggestion break @@ -3245,14 +3298,12 @@ async def _handle_connect_datasource( console.print() - mode_answer = ( - Prompt.ask( - "[anton.cyan](anton)[/] Do you have these available? [y/n/]", - console=console, - ) - .strip() - .lower() + mode_answer = _prompt_or_cancel( + "(anton) Do you have these available? (y/n/)", ) + if mode_answer is None: + return session + mode_answer = mode_answer.strip().lower() if mode_answer == "n": console.print() @@ -3331,16 +3382,10 @@ async def _handle_connect_datasource( f'[anton.warning](anton)[/] A connection [bold]"{slug}"[/bold] already exists.' ) console.print() - choice = ( - Prompt.ask( - "[anton.cyan](anton)[/] [reconnect/cancel]", - console=console, - default="cancel", - ) - .strip() - .lower() + choice = _prompt_or_cancel( + f"(anton) {_PROMPT_RECONNECT_CANCEL}", ) - if choice != "reconnect": + if choice is None or choice.strip().lower() != "reconnect": console.print("[anton.muted]Cancelled.[/]") console.print() return session @@ -3820,10 +3865,6 @@ async def _chat_loop( toolbar = {"stats": "", "status": ""} display = StreamDisplay(console, toolbar=toolbar) - from prompt_toolkit import PromptSession - from prompt_toolkit.formatted_text import HTML - from prompt_toolkit.styles import Style as PTStyle - def _bottom_toolbar(): stats = toolbar["stats"] status = toolbar["status"] diff --git a/tests/test_datasource.py b/tests/test_datasource.py index b6ae852..364d915 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -13,6 +13,7 @@ ChatSession, _DS_KNOWN_VARS, _DS_SECRET_VARS, + _PROMPT_RECONNECT_CANCEL, _build_datasource_context, _handle_add_custom_datasource, _handle_connect_datasource, @@ -559,7 +560,7 @@ async def test_unknown_engine_returns_early(self, registry, vault_dir, make_sess with ( patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", return_value="MySQL"), + patch("anton.chat._prompt_or_cancel", return_value="MySQL"), ): result = await _handle_connect_datasource(console, session._scratchpads, session) @@ -572,12 +573,14 @@ async def test_partial_save_on_n_answer(self, registry, vault_dir, make_session) session = make_session() console = MagicMock() vault = DataVault(vault_dir=vault_dir) - prompt_responses = iter(["PostgreSQL", "n", "db.example.com", "", "", "", "", ""]) + poc_responses = iter(["PostgreSQL", "n"]) + pa_responses = iter(["db.example.com", "", "", "", "", ""]) with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), ): await _handle_connect_datasource(console, session._scratchpads, session) @@ -601,15 +604,14 @@ async def test_successful_connection_saves_and_injects_history( pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) - prompt_responses = iter([ - "PostgreSQL", "y", - "db.example.com", "5432", "prod_db", "alice", "s3cr3t", "", - ]) + poc_responses = iter(["PostgreSQL", "y"]) + pa_responses = iter(["db.example.com", "5432", "prod_db", "alice", "s3cr3t", ""]) with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), ): result = await _handle_connect_datasource(console, session._scratchpads, session) @@ -638,17 +640,17 @@ async def test_failed_test_offers_retry(self, registry, vault_dir, make_session, ]) session._scratchpads.get_or_create = AsyncMock(return_value=pad) - prompt_responses = iter([ - "PostgreSQL", "y", + poc_responses = iter(["PostgreSQL", "y", "y"]) # engine, do-you-have, retry? + pa_responses = iter([ "db.example.com", "5432", "prod_db", "alice", "wrongpassword", "", - "y", # retry? "correctpassword", ]) with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), ): await _handle_connect_datasource(console, session._scratchpads, session) @@ -671,16 +673,14 @@ async def test_failed_test_no_retry_returns_without_saving( pad.execute = AsyncMock(return_value=make_cell(stdout="", error="connection refused")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) - prompt_responses = iter([ - "PostgreSQL", "y", - "db.example.com", "5432", "prod_db", "alice", "badpass", "", - "n", - ]) + poc_responses = iter(["PostgreSQL", "y", "n"]) # engine, do-you-have, retry? + pa_responses = iter(["db.example.com", "5432", "prod_db", "alice", "badpass", ""]) with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), ): result = await _handle_connect_datasource(console, session._scratchpads, session) @@ -700,15 +700,14 @@ async def test_ds_env_injected_after_successful_connect( pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) - prompt_responses = iter([ - "PostgreSQL", "y", - "db.example.com", "5432", "prod_db", "alice", "s3cr3t", "", - ]) + poc_responses = iter(["PostgreSQL", "y"]) + pa_responses = iter(["db.example.com", "5432", "prod_db", "alice", "s3cr3t", ""]) with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), ): await _handle_connect_datasource(console, session._scratchpads, session) @@ -728,12 +727,14 @@ async def test_auth_method_choice_selects_fields( pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) - prompt_responses = iter(["HubSpot", "1", "y", "pat-na1-abc123"]) + poc_responses = iter(["HubSpot", "y"]) # engine, do-you-have + pa_responses = iter(["1", "pat-na1-abc123"]) # auth method choice, field value with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), ): await _handle_connect_datasource(console, session._scratchpads, session) @@ -759,15 +760,14 @@ async def test_selective_field_collection( pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) - prompt_responses = iter([ - "PostgreSQL", "host,user,password", - "db.example.com", "alice", "s3cr3t", - ]) + poc_responses = iter(["PostgreSQL", "host,user,password"]) # engine, do-you-have (list params) + pa_responses = iter(["db.example.com", "alice", "s3cr3t"]) # field values with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), ): await _handle_connect_datasource(console, session._scratchpads, session) @@ -845,15 +845,14 @@ async def test_register_and_scrub_on_connect(self, registry, vault_dir, monkeypa session._scratchpads.get_or_create = AsyncMock(return_value=pad) secret_pw = "supersecretpassword999" - prompt_responses = iter([ - "PostgreSQL", "y", - "db.host.com", "5432", "mydb", "alice", secret_pw, "public", - ]) + poc_responses = iter(["PostgreSQL", "y"]) + pa_responses = iter(["db.host.com", "5432", "mydb", "alice", secret_pw, "public"]) with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), ): await _handle_connect_datasource(MagicMock(), session._scratchpads, session) @@ -1349,15 +1348,14 @@ async def test_connect_clears_previous_ds_vars( pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) - prompt_responses = iter([ - "PostgreSQL", "y", - "db.example.com", "5432", "prod_db", "alice", "s3cr3t", "", - ]) + poc_responses = iter(["PostgreSQL", "y"]) + pa_responses = iter(["db.example.com", "5432", "prod_db", "alice", "s3cr3t", ""]) with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), ): await _handle_connect_datasource(console, session._scratchpads, session) @@ -1426,7 +1424,7 @@ async def test_edit_data_source_no_arg_safe(self, vault_dir, registry, make_sess with ( patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", return_value="UnknownEngine"), + patch("anton.chat._prompt_or_cancel", return_value="UnknownEngine"), ): updated = await _handle_connect_datasource( console, session._scratchpads, session, @@ -1731,14 +1729,9 @@ async def test_missing_required_non_secret_field_prompts_user( console = MagicMock() registry = self._make_registry(tmp_path) - # first response: initial auth question; second: the host field prompt - prompt_responses = iter(["I want to connect to mydb", "localhost"]) - with ( - patch( - "rich.prompt.Prompt.ask", - side_effect=lambda *a, **kw: next(prompt_responses), - ), + patch("anton.chat._prompt_or_cancel", return_value="I want to connect to mydb"), + patch("rich.prompt.Prompt.ask", return_value="localhost"), patch("anton.chat.Path") as mock_path_cls, ): self._mock_ds_path(mock_path_cls, tmp_path) @@ -1768,12 +1761,11 @@ async def test_missing_required_secret_field_prompts_user( password_calls = [] def fake_prompt(*args, **kwargs): - if kwargs.get("password"): - password_calls.append(kwargs) - return "mysecret" - return "I want to connect" + password_calls.append(kwargs) + return "mysecret" with ( + patch("anton.chat._prompt_or_cancel", return_value="I want to connect"), patch("rich.prompt.Prompt.ask", side_effect=fake_prompt), patch("anton.chat.Path") as mock_path_cls, ): @@ -1806,14 +1798,9 @@ async def test_incomplete_custom_datasource_not_saved( console = MagicMock() registry = self._make_registry(tmp_path) - # User presses Enter (empty) for every prompt - prompt_responses = iter(["I want to connect", "", ""]) - with ( - patch( - "rich.prompt.Prompt.ask", - side_effect=lambda *a, **kw: next(prompt_responses), - ), + patch("anton.chat._prompt_or_cancel", return_value="I want to connect"), + patch("rich.prompt.Prompt.ask", return_value=""), patch("anton.chat.Path") as mock_path_cls, ): self._mock_ds_path(mock_path_cls, tmp_path) @@ -1884,17 +1871,14 @@ async def test_custom_with_test_snippet_success( test_snippet="print('ok')", )) - prompt_responses = iter([ - "0", # choose custom - "My API Service", # tool name - "I have an API key", # auth description - "my_secret_key", # api_key (secret prompt) - ]) + poc_responses = iter(["0", "My API Service", "I have an API key"]) # engine sel, tool name, auth desc + pa_responses = iter(["my_secret_key"]) with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), patch("anton.chat.Path") as mock_path_cls, ): self._mock_ds_path(mock_path_cls, tmp_path) @@ -1926,18 +1910,14 @@ async def test_custom_with_test_snippet_fail_no_retry( test_snippet="print('ok')", )) - prompt_responses = iter([ - "0", - "My API Service", - "I have an API key", - "bad_key", # api_key - "n", # retry? - ]) + poc_responses = iter(["0", "My API Service", "I have an API key", "n"]) # engine sel, tool name, auth desc, retry? + pa_responses = iter(["bad_key"]) with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), patch("anton.chat.Path") as mock_path_cls, ): self._mock_ds_path(mock_path_cls, tmp_path) @@ -1966,19 +1946,14 @@ async def test_custom_with_test_snippet_fail_retry_success( test_snippet="print('ok')", )) - prompt_responses = iter([ - "0", - "My API Service", - "I have an API key", - "bad_key", # api_key first attempt - "y", # retry? - "good_key", # api_key retry - ]) + poc_responses = iter(["0", "My API Service", "I have an API key", "y"]) # engine sel, tool name, auth desc, retry? + pa_responses = iter(["bad_key", "good_key"]) with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), patch("anton.chat.Path") as mock_path_cls, ): self._mock_ds_path(mock_path_cls, tmp_path) @@ -2007,17 +1982,14 @@ async def test_custom_without_test_snippet_saves( test_snippet="", )) - prompt_responses = iter([ - "0", - "My API Service", - "I have an API key", - "my_key", # api_key - ]) + poc_responses = iter(["0", "My API Service", "I have an API key"]) # engine sel, tool name, auth desc + pa_responses = iter(["my_key"]) with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), patch("anton.chat.Path") as mock_path_cls, ): self._mock_ds_path(mock_path_cls, tmp_path) @@ -2061,23 +2033,14 @@ async def test_edit_failed_test_does_not_corrupt_vault( self._setup_pad(session, make_cell(stdout="", stderr="connection refused")) # Keep all non-secret fields; enter bad password; decline retry. - prompt_responses = iter([ - "", # host (keep existing) - "", # port - "", # database - "", # user - "bad-pass", # password (new, bad) - "", # schema - "n", # retry? - ]) + poc_responses = iter(["n"]) # retry? + pa_responses = iter(["", "", "", "", "bad-pass", ""]) # field values with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch( - "rich.prompt.Prompt.ask", - side_effect=lambda *a, **kw: next(prompt_responses), - ), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_responses)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(pa_responses)), ): result = await _handle_connect_datasource( console, session._scratchpads, session, @@ -2164,7 +2127,7 @@ async def test_connection_test_error_summary_uses_meaningful_line( with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", return_value="n"), + patch("anton.chat._prompt_or_cancel", return_value="n"), ): result = await _run_connection_test( console, scratchpads, vault, engine_def, credentials, @@ -2174,4 +2137,123 @@ async def test_connection_test_error_summary_uses_meaningful_line( assert result is False printed = " ".join(str(c) for c in console.print.call_args_list) assert "psycopg2.OperationalError" in printed - assert "Traceback (most recent call last)" not in printed + + +# ───────────────────────────────────────────────────────────────────────────── +# Prompt copy consistency +# ───────────────────────────────────────────────────────────────────────────── + + +class TestPromptCopyConsistency: + """Verify that interactive prompts use (y/n) style and Esc cancels safely.""" + + @pytest.mark.asyncio + async def test_esc_on_engine_selection_returns_session_unchanged( + self, registry, vault_dir, make_session + ): + """Pressing Esc on the engine-selection prompt returns the session with no vault writes.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("anton.chat._prompt_or_cancel", return_value=None), + ): + result = await _handle_connect_datasource(console, session._scratchpads, session) + + assert result is session + assert vault.list_connections() == [] + + @pytest.mark.asyncio + async def test_esc_on_retry_does_not_save(self, registry, vault_dir, make_session, make_cell): + """Pressing Esc at the retry prompt makes _run_connection_test return False.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="", error="bad creds")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("anton.chat._prompt_or_cancel", return_value=None), + ): + engine_def = registry.get("postgresql") + credentials = {"host": "h", "port": "5432", "database": "d", "user": "u", "password": "p"} + result = await _run_connection_test( + console, session._scratchpads, vault, engine_def, credentials, + retry_fields=engine_def.fields, + ) + + assert result is False + assert vault.list_connections() == [] + + @pytest.mark.asyncio + async def test_esc_on_do_you_have_these_returns_session( + self, registry, vault_dir, make_session + ): + """Pressing Esc after engine selection (on 'do you have these?') returns session.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + poc_calls = iter(["PostgreSQL", None]) # engine selected, then Esc + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("anton.chat._prompt_or_cancel", side_effect=lambda *a, **kw: next(poc_calls)), + ): + result = await _handle_connect_datasource(console, session._scratchpads, session) + + assert result is session + assert vault.list_connections() == [] + + @pytest.mark.asyncio + async def test_fuzzy_match_prompt_has_context_text(self, registry, vault_dir, make_session): + """The fuzzy-match confirmation prompt includes context text and uses (y/n).""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + captured_labels: list[str] = [] + + def _capture(label, **kw): + captured_labels.append(label) + return None # Esc on every prompt to bail out + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("anton.chat._prompt_or_cancel", side_effect=_capture), + ): + # "PostgreeSQL" triggers fuzzy match against "PostgreSQL" + await _handle_connect_datasource( + console, session._scratchpads, session, datasource_name=None, + ) + + # Find the fuzzy-confirm label (contains "Use this datasource?") + fuzzy_labels = [lbl for lbl in captured_labels if "Use this datasource?" in lbl] + # If no fuzzy suggestions were generated, just verify the prompt constants are correct. + if fuzzy_labels: + lbl = fuzzy_labels[0] + assert "(y/n)" in lbl + assert "[y/n]" not in lbl + + @pytest.mark.parametrize("label", [ + "(y/n)", + "(reconnect/cancel)", + "(anton) Would you like to re-enter your credentials? (y/n)", + "(anton) Use this datasource? (y/n)", + "(anton) Do you have these available? (y/n/)", + "(anton) (reconnect/cancel)", + ]) + def test_canonical_labels_no_bracket_style(self, label): + """None of the canonical prompt strings use the old bracket style.""" + assert "[y/n]" not in label + assert "[reconnect/cancel]" not in label + assert "(y/n" in label or _PROMPT_RECONNECT_CANCEL in label From b8f3292b02657530d147cb45708136e4b11ef766 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 27 Mar 2026 19:45:16 +0100 Subject: [PATCH 7/8] Improve the connection menu --- anton/chat.py | 78 +++++++++++++++++++++++++++++++----- anton/config/datasources.md | 18 +++++++++ anton/datasource_registry.py | 13 +++++- 3 files changed, 97 insertions(+), 12 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index 36ca6cf..59b6223 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -3096,26 +3096,84 @@ async def _handle_connect_datasource( console.print() all_engines = registry.all_engines() + popular_engines = [e for e in all_engines if e.popular and not e.custom] + other_engines = [e for e in all_engines if not e.popular and not e.custom] + custom_engines = [e for e in all_engines if e.custom] + display_engines = popular_engines + other_engines + custom_engines + + def _print_sections() -> None: + console.print( + "[anton.cyan](anton)[/] Choose a data source:\n" + ) + console.print(" [bold] Primary") + console.print( + " [bold] 0.[/bold] Custom datasource" + " (connect anything via API, SQL, or MCP)\n" + ) + if popular_engines: + console.print(" [bold] Most popular") + for i, e in enumerate(popular_engines, 1): + console.print(f" [bold]{i:>2}.[/bold] {e.display_name}") + console.print() + if other_engines: + start = len(popular_engines) + 1 + console.print(" [bold] Other connectors") + for i, e in enumerate(other_engines[:3], start): + console.print(f" [bold]{i:>2}.[/bold] {e.display_name}") + if len(other_engines) > 3: + console.print( + f" [anton.muted] … and" + f" {len(other_engines) - 3} more" + f" (type 'all' to see all)[/]" + ) + console.print() + if custom_engines: + start = len(popular_engines) + len(other_engines) + 1 + console.print(" [bold] Custom") + for i, e in enumerate(custom_engines, start): + console.print(f" [bold]{i:>2}.[/bold] {e.display_name}") + console.print() + + def _print_all() -> None: + console.print( + "[anton.cyan](anton)[/] All data sources (★ = popular):\n" + ) + console.print(" [bold] Primary") + console.print( + " [bold] 0.[/bold] Custom datasource" + " (connect anything via API, SQL, or MCP)\n" + ) + for i, e in enumerate(display_engines, 1): + star = " ★" if e.popular else "" + console.print(f" [bold]{i:>2}.[/bold] {e.display_name}{star}") + console.print() + if prefill: answer = prefill else: + _print_sections() console.print( - "[anton.cyan](anton)[/] Choose a data source:\n" + " [anton.muted]Type 'all' to see every datasource.[/]" ) - console.print(" [bold] Primary") - console.print(" [bold] 0.[/bold] Custom datasource (connect anything via API, SQL, or MCP)\n") - console.print(" [bold] Predefined") - for i, e in enumerate(all_engines, 1): - console.print(f" [bold]{i:>2}.[/bold] {e.display_name}") console.print() answer = _prompt_or_cancel( "(anton) Enter a number or type a name", ) if answer is None: return session + if answer.strip().lower() == "all": + console.print() + _print_all() + answer = _prompt_or_cancel( + "(anton) Enter a number or type a name", + ) + if answer is None: + return session stripped_answer = answer.strip() - known_slugs = {f"{c['engine']}-{c['name']}": c for c in vault.list_connections()} + known_slugs = { + f"{c['engine']}-{c['name']}": c for c in vault.list_connections() + } if stripped_answer in known_slugs: conn = known_slugs[stripped_answer] _restore_namespaced_env(vault) @@ -3149,12 +3207,12 @@ async def _handle_connect_datasource( pick_num = int(stripped_answer) if pick_num == 0: custom_source = True - elif 1 <= pick_num <= len(all_engines): - engine_def = all_engines[pick_num - 1] + elif 1 <= pick_num <= len(display_engines): + engine_def = display_engines[pick_num - 1] else: console.print( f"[anton.warning](anton)[/] '{stripped_answer}' is out of range. " - f"Please enter 0–{len(all_engines)}.[/]" + f"Please enter 0–{len(display_engines)}.[/]" ) console.print() return session diff --git a/anton/config/datasources.md b/anton/config/datasources.md index 2cafc88..bd4773d 100644 --- a/anton/config/datasources.md +++ b/anton/config/datasources.md @@ -16,6 +16,7 @@ engine: postgres display_name: PostgreSQL pip: psycopg2-binary name_from: database +popular: true fields: - { name: host, required: true, secret: false, description: "hostname or IP of your database server" } - { name: port, required: true, secret: false, description: "port number", default: "5432" } @@ -47,6 +48,7 @@ engine: mysql display_name: MySQL pip: mysql-connector-python name_from: database +popular: true fields: - { name: host, required: true, secret: false, description: "hostname or IP of your MySQL server" } - { name: port, required: true, secret: false, description: "port number", default: "3306" } @@ -78,6 +80,7 @@ test_snippet: | engine: snowflake display_name: Snowflake pip: snowflake-connector-python +popular: true name_from: [account, database] auth_method: choice auth_methods: @@ -127,6 +130,7 @@ Format is either `-` or `.. engine: bigquery display_name: Google BigQuery pip: google-cloud-bigquery +popular: true name_from: [project_id, dataset] fields: - { name: project_id, required: true, secret: false, description: "GCP project ID containing your BigQuery datasets" } @@ -162,6 +166,7 @@ Keys → Add Key → JSON. Grant the account `BigQuery Data Viewer` + `BigQuery engine: mssql display_name: Microsoft SQL Server pip: pymssql +popular: true name_from: database fields: - { name: host, required: true, secret: false, description: "hostname or IP of the SQL Server (for Azure use server field instead)" } @@ -195,6 +200,7 @@ For Windows Authentication omit user/password and ensure pymssql is built with K engine: redshift display_name: Amazon Redshift pip: psycopg2-binary +popular: true name_from: [host, database] fields: - { name: host, required: true, secret: false, description: "Redshift cluster endpoint (e.g. mycluster.abc123.us-east-1.redshift.amazonaws.com)" } @@ -229,6 +235,7 @@ Redshift → Clusters → your cluster → Endpoint (omit the port suffix). engine: databricks display_name: Databricks pip: databricks-sql-connector +popular: true name_from: [server_hostname, catalog] fields: - { name: server_hostname, required: true, secret: false, description: "server hostname for the cluster or SQL warehouse (from JDBC/ODBC connection string)" } @@ -262,6 +269,7 @@ HTTP path and server hostname: SQL Warehouses → your warehouse → Connection engine: mariadb display_name: MariaDB pip: mysql-connector-python +popular: true name_from: database fields: - { name: host, required: true, secret: false, description: "hostname or IP of your MariaDB server" } @@ -296,6 +304,7 @@ MariaDB is wire-compatible with MySQL, so the mysql-connector-python driver work engine: hubspot display_name: HubSpot pip: hubspot-api-client +popular: true auth_method: choice auth_methods: - name: pat @@ -338,6 +347,7 @@ For OAuth2: collect client_id and client_secret, then use the scratchpad to: engine: oracle_database display_name: Oracle Database pip: oracledb +popular: true name_from: [host, service_name] fields: - { name: user, required: true, secret: false, description: "Oracle database username" } @@ -371,6 +381,7 @@ Set `auth_mode` to `SYSDBA` or `SYSOPER` for privileged connections. engine: duckdb display_name: DuckDB pip: duckdb +popular: false name_from: database fields: - { name: database, required: false, secret: false, description: "path to DuckDB database file; omit or use :memory: for in-memory database", default: ":memory:" } @@ -400,6 +411,7 @@ the access token. For local files, provide the path to a `.duckdb` file. engine: pgvector display_name: pgvector pip: pgvector psycopg2-binary +popular: false name_from: database fields: - { name: host, required: true, secret: false, description: "hostname or IP of your PostgreSQL server with pgvector extension" } @@ -437,6 +449,7 @@ Managed options: Supabase, Neon, and AWS RDS for PostgreSQL all support pgvector engine: chromadb display_name: ChromaDB pip: chromadb +popular: false name_from: host fields: - { name: host, required: true, secret: false, description: "ChromaDB server host for HTTP client mode (omit for local in-process mode)" } @@ -471,6 +484,7 @@ or ephemeral in-memory. For production, run `chroma run` to start the HTTP serve engine: salesforce display_name: Salesforce pip: salesforce_api +popular: true name_from: username fields: - { name: username, required: true, secret: false, description: "Salesforce account username (email)" } @@ -502,6 +516,7 @@ Enable OAuth, add callback URL, select scopes (api, refresh_token). engine: shopify display_name: Shopify pip: ShopifyAPI +popular: true name_from: shop_url fields: - { name: shop_url, required: true, secret: false, description: "your Shopify store URL (e.g. mystore.myshopify.com)" } @@ -530,6 +545,7 @@ Grant required API permissions (read_products, read_orders, etc.) then install t engine: netsuite display_name: NetSuite pip: requests-oauthlib>=1.3.1 +popular: false name_from: account_id fields: - { name: account_id, required: true, secret: false, description: "NetSuite account/realm ID (e.g. 123456_SB1)" } @@ -570,6 +586,7 @@ Setup → Users/Roles → Access Tokens. The account ID can be found in Setup engine: bigcommerce display_name: Big Commerce pip: httpx +popular: false name_from: store_hash fields: - { name: api_base, required: true, secret: false, description: "Base URL of the BigCommerce API (e.g. https://api.bigcommerce.com/stores/0fh0fh0fh0/v3/)" } @@ -629,6 +646,7 @@ and self-hosted PostgreSQL with the TimescaleDB extension installed. engine: email display_name: Email name_from: email +popular: false fields: - { name: email, required: true, secret: false, description: "email address to connect" } - { name: password, required: true, secret: true, description: "email account password or app-specific password" } diff --git a/anton/datasource_registry.py b/anton/datasource_registry.py index 5777856..e932807 100644 --- a/anton/datasource_registry.py +++ b/anton/datasource_registry.py @@ -37,6 +37,9 @@ class DatasourceEngine: auth_method: str = "" auth_methods: list[AuthMethod] = field(default_factory=list) test_snippet: str = "" + popular: bool = False + # True for engines defined in ~/.anton/datasources.md + custom: bool = False # Matches a level-2 heading followed by a ```yaml fenced block. @@ -63,7 +66,9 @@ def _parse_fields(raw: list) -> list[DatasourceField]: return result -def _parse_file(path: Path) -> dict[str, DatasourceEngine]: +def _parse_file( + path: Path, *, custom: bool = False +) -> dict[str, DatasourceEngine]: """Extract engine definitions from a datasources.md file.""" if not path.is_file(): return {} @@ -108,6 +113,8 @@ def _parse_file(path: Path) -> dict[str, DatasourceEngine]: auth_method=str(data.get("auth_method", "")), auth_methods=auth_methods, test_snippet=str(data.get("test_snippet", "")), + popular=bool(data.get("popular", False)), + custom=custom, ) return engines @@ -125,7 +132,9 @@ def __init__(self) -> None: def _load(self) -> None: self._engines = _parse_file(self._BUILTIN_PATH) - for slug, engine in _parse_file(self._USER_PATH).items(): + for slug, engine in _parse_file( + self._USER_PATH, custom=True + ).items(): self._engines[slug] = engine def reload(self) -> None: From 1147d990f02caed7ea242fff8318e52bc10d834c Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 27 Mar 2026 19:49:32 +0100 Subject: [PATCH 8/8] Clean custom registry --- anton/chat.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/anton/chat.py b/anton/chat.py index 59b6223..2bc415b 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -3552,6 +3552,19 @@ def _handle_remove_data_source(console: Console, slug: str) -> None: ): vault.delete(engine, name) _restore_namespaced_env(vault) + engine_def = registry.get(engine) + if engine_def is not None and engine_def.custom: + remaining = [ + c for c in vault.list_connections() if c["engine"] == engine + ] + if not remaining: + user_path = DatasourceRegistry._USER_PATH + if user_path.is_file(): + updated = _remove_engine_block( + user_path.read_text(encoding="utf-8"), engine + ) + user_path.write_text(updated, encoding="utf-8") + registry.reload() console.print(f"[anton.success]Removed {slug}.[/]") else: console.print("[anton.muted]Cancelled.[/]")