diff --git a/dev_env/trino/test_trino.py b/dev_env/trino/test_trino.py index e0684e9..2efb91e 100644 --- a/dev_env/trino/test_trino.py +++ b/dev_env/trino/test_trino.py @@ -4,7 +4,6 @@ env_config: Dict[str, Any] = combine_env_configs() - client = TrinoConnector( host=env_config["TRINO_HOST"], port=env_config["TRINO_PORT"], @@ -12,9 +11,17 @@ ) # Find +print(client.query("SHOW CATALOGS")) print(client.query("SELECT * FROM pg.public.employees")) +mylist = [{'name':'Pat','department':'Facilities MGMT'}] +print(client.bulk_insert('pg.public.employees',mylist)) -mylist = [{'name':'Pat','department':'Facilities MGMT'}] +client_https = TrinoConnector( + host=env_config["TRINO_HOST_2"], + port=env_config["TRINO_PORT_2"], + user=env_config["TRINO_USER_2"], + http_scheme='https' +) -print(client.bulk_insert('pg.public.employees',mylist)) \ No newline at end of file +print(client_https.query("SHOW CATALOGS")) \ No newline at end of file diff --git a/src/pyapiary/dbms_connectors/trino.py b/src/pyapiary/dbms_connectors/trino.py index 2107ade..c65cced 100644 --- a/src/pyapiary/dbms_connectors/trino.py +++ b/src/pyapiary/dbms_connectors/trino.py @@ -2,13 +2,9 @@ from typing import List, Dict, Any class TrinoConnector: - def __init__(self, host, port, user, catalog=None, schema=None): + def __init__(self, **kwargs): self.conn = connect( - host=host, - port=port, - user=user, - catalog=catalog, - schema=schema + **kwargs ) def query(self, query_str): diff --git a/src/pyapiary/tests/test_trino/test_unit_trino.py b/src/pyapiary/tests/test_trino/test_unit_trino.py index fb5113d..c448da5 100644 --- a/src/pyapiary/tests/test_trino/test_unit_trino.py +++ b/src/pyapiary/tests/test_trino/test_unit_trino.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, patch from pyapiary.dbms_connectors.trino import TrinoConnector @@ -15,11 +15,12 @@ def mock_connect(mocker): def connector(mock_connect): """Return a TrinoConnector backed by a mocked connection.""" return TrinoConnector( - host="localhost", - port=8080, + host="trino.trashcollector.dev", + port=443, user="test_user", catalog="hive", schema="default", + http_scheme="https", ) @@ -28,31 +29,33 @@ def connector(mock_connect): # --------------------------------------------------------------------------- class TestInit: - def test_connect_called_with_correct_args(self, mocker): + def test_connect_called_with_provided_kwargs(self, mocker): mock_connect = mocker.patch("pyapiary.dbms_connectors.trino.connect") - TrinoConnector(host="myhost", port=9090, user="alice", catalog="iceberg", schema="raw") + TrinoConnector(host="myhost", port=443, user="alice", http_scheme="https") mock_connect.assert_called_once_with( host="myhost", - port=9090, + port=443, user="alice", - catalog="iceberg", - schema="raw", + http_scheme="https", ) - def test_connect_called_without_optional_args(self, mocker): + def test_connect_called_with_minimal_kwargs(self, mocker): mock_connect = mocker.patch("pyapiary.dbms_connectors.trino.connect") - TrinoConnector(host="myhost", port=9090, user="alice") - mock_connect.assert_called_once_with( - host="myhost", - port=9090, - user="alice", - catalog=None, - schema=None, - ) + TrinoConnector(host="myhost", port=8080, user="alice") + mock_connect.assert_called_once_with(host="myhost", port=8080, user="alice") def test_conn_attribute_set(self, mock_connect, connector): assert connector.conn is mock_connect + def test_arbitrary_kwargs_forwarded(self, mocker): + """Any kwarg the trino client supports should be forwarded as-is.""" + mock_connect = mocker.patch("pyapiary.dbms_connectors.trino.connect") + TrinoConnector(host="h", port=443, user="u", http_scheme="https", + verify=False, session_properties={"query_max_run_time": "1h"}) + _, call_kwargs = mock_connect.call_args + assert call_kwargs["verify"] is False + assert call_kwargs["session_properties"] == {"query_max_run_time": "1h"} + # --------------------------------------------------------------------------- # query() @@ -108,7 +111,17 @@ def test_propagates_execute_exception(self, mock_connect, connector): mock_connect.cursor.return_value.__enter__.return_value = mock_cursor with pytest.raises(RuntimeError, match="syntax error"): - connector.query("SELECT bad syntax %%") + connector.query("SELECT bad %%") + + def test_show_catalogs(self, mock_connect, connector): + mock_cursor = MagicMock() + mock_cursor.description = [("Catalog",)] + mock_cursor.fetchall.return_value = [("hive",), ("iceberg",), ("tpch",)] + mock_connect.cursor.return_value.__enter__.return_value = mock_cursor + + result = connector.query("SHOW CATALOGS") + + assert result == [("hive",), ("iceberg",), ("tpch",)] # --------------------------------------------------------------------------- @@ -123,27 +136,18 @@ def test_inserts_single_row(self, mock_connect, connector): result = connector.bulk_insert("my_table", [{"id": 1, "name": "alice"}]) expected_query = "INSERT INTO my_table (id, name) VALUES (?, ?)" - mock_cursor.executemany.assert_called_once_with( - expected_query, [(1, "alice")] - ) + mock_cursor.executemany.assert_called_once_with(expected_query, [(1, "alice")]) assert result is True def test_inserts_multiple_rows(self, mock_connect, connector): mock_cursor = MagicMock() mock_connect.cursor.return_value.__enter__.return_value = mock_cursor - data = [ - {"id": 1, "val": "a"}, - {"id": 2, "val": "b"}, - {"id": 3, "val": "c"}, - ] + data = [{"id": 1, "val": "a"}, {"id": 2, "val": "b"}, {"id": 3, "val": "c"}] result = connector.bulk_insert("my_table", data) - expected_values = [(1, "a"), (2, "b"), (3, "c")] - _, call_values = mock_cursor.executemany.call_args - # positional args actual_values = mock_cursor.executemany.call_args[0][1] - assert actual_values == expected_values + assert actual_values == [(1, "a"), (2, "b"), (3, "c")] assert result is True def test_returns_none_for_empty_list(self, mock_connect, connector, capsys): @@ -165,10 +169,10 @@ def test_query_string_uses_correct_table_name(self, mock_connect, connector): mock_cursor = MagicMock() mock_connect.cursor.return_value.__enter__.return_value = mock_cursor - connector.bulk_insert("schema.target_table", [{"x": 99}]) + connector.bulk_insert("hive.default.target_table", [{"x": 99}]) actual_query = mock_cursor.executemany.call_args[0][0] - assert "schema.target_table" in actual_query + assert "hive.default.target_table" in actual_query def test_column_order_matches_first_row_keys(self, mock_connect, connector): mock_cursor = MagicMock() @@ -178,10 +182,19 @@ def test_column_order_matches_first_row_keys(self, mock_connect, connector): connector.bulk_insert("t", data) actual_query = mock_cursor.executemany.call_args[0][0] - # columns in query should match key order of first dict for col in ["z", "a", "m"]: assert col in actual_query + def test_placeholder_count_matches_column_count(self, mock_connect, connector): + mock_cursor = MagicMock() + mock_connect.cursor.return_value.__enter__.return_value = mock_cursor + + data = [{"a": 1, "b": 2, "c": 3, "d": 4}] + connector.bulk_insert("t", data) + + actual_query = mock_cursor.executemany.call_args[0][0] + assert actual_query.count("?") == 4 + def test_propagates_executemany_exception(self, mock_connect, connector): mock_cursor = MagicMock() mock_cursor.executemany.side_effect = RuntimeError("DB write error")