Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions dev_env/trino/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,24 @@

env_config: Dict[str, Any] = combine_env_configs()


client = TrinoConnector(
host=env_config["TRINO_HOST"],
port=env_config["TRINO_PORT"],
user=env_config["TRINO_USER"]
)

# Find
print(client.query("SHOW CATALOGS"))
print(client.query("SELECT * FROM pg.public.employees"))
mylist = [{'name':'Pat','department':'Facilities MGMT'}]
print(client.bulk_insert('pg.public.employees',mylist))


mylist = [{'name':'Pat','department':'Facilities MGMT'}]
client_https = TrinoConnector(
host=env_config["TRINO_HOST_2"],
port=env_config["TRINO_PORT_2"],
user=env_config["TRINO_USER_2"],
http_scheme='https'
)

print(client.bulk_insert('pg.public.employees',mylist))
print(client_https.query("SHOW CATALOGS"))
8 changes: 2 additions & 6 deletions src/pyapiary/dbms_connectors/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@
from typing import List, Dict, Any

class TrinoConnector:
def __init__(self, host, port, user, catalog=None, schema=None):
def __init__(self, **kwargs):
self.conn = connect(
host=host,
port=port,
user=user,
catalog=catalog,
schema=schema
**kwargs
)

def query(self, query_str):
Expand Down
79 changes: 46 additions & 33 deletions src/pyapiary/tests/test_trino/test_unit_trino.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from unittest.mock import MagicMock, patch, call
from unittest.mock import MagicMock, patch
from pyapiary.dbms_connectors.trino import TrinoConnector


Expand All @@ -15,11 +15,12 @@ def mock_connect(mocker):
def connector(mock_connect):
"""Return a TrinoConnector backed by a mocked connection."""
return TrinoConnector(
host="localhost",
port=8080,
host="trino.trashcollector.dev",
port=443,
user="test_user",
catalog="hive",
schema="default",
http_scheme="https",
)


Expand All @@ -28,31 +29,33 @@ def connector(mock_connect):
# ---------------------------------------------------------------------------

class TestInit:
def test_connect_called_with_correct_args(self, mocker):
def test_connect_called_with_provided_kwargs(self, mocker):
mock_connect = mocker.patch("pyapiary.dbms_connectors.trino.connect")
TrinoConnector(host="myhost", port=9090, user="alice", catalog="iceberg", schema="raw")
TrinoConnector(host="myhost", port=443, user="alice", http_scheme="https")
mock_connect.assert_called_once_with(
host="myhost",
port=9090,
port=443,
user="alice",
catalog="iceberg",
schema="raw",
http_scheme="https",
)

def test_connect_called_without_optional_args(self, mocker):
def test_connect_called_with_minimal_kwargs(self, mocker):
mock_connect = mocker.patch("pyapiary.dbms_connectors.trino.connect")
TrinoConnector(host="myhost", port=9090, user="alice")
mock_connect.assert_called_once_with(
host="myhost",
port=9090,
user="alice",
catalog=None,
schema=None,
)
TrinoConnector(host="myhost", port=8080, user="alice")
mock_connect.assert_called_once_with(host="myhost", port=8080, user="alice")

def test_conn_attribute_set(self, mock_connect, connector):
assert connector.conn is mock_connect

def test_arbitrary_kwargs_forwarded(self, mocker):
"""Any kwarg the trino client supports should be forwarded as-is."""
mock_connect = mocker.patch("pyapiary.dbms_connectors.trino.connect")
TrinoConnector(host="h", port=443, user="u", http_scheme="https",
verify=False, session_properties={"query_max_run_time": "1h"})
_, call_kwargs = mock_connect.call_args
assert call_kwargs["verify"] is False
assert call_kwargs["session_properties"] == {"query_max_run_time": "1h"}


# ---------------------------------------------------------------------------
# query()
Expand Down Expand Up @@ -108,7 +111,17 @@ def test_propagates_execute_exception(self, mock_connect, connector):
mock_connect.cursor.return_value.__enter__.return_value = mock_cursor

with pytest.raises(RuntimeError, match="syntax error"):
connector.query("SELECT bad syntax %%")
connector.query("SELECT bad %%")

def test_show_catalogs(self, mock_connect, connector):
mock_cursor = MagicMock()
mock_cursor.description = [("Catalog",)]
mock_cursor.fetchall.return_value = [("hive",), ("iceberg",), ("tpch",)]
mock_connect.cursor.return_value.__enter__.return_value = mock_cursor

result = connector.query("SHOW CATALOGS")

assert result == [("hive",), ("iceberg",), ("tpch",)]


# ---------------------------------------------------------------------------
Expand All @@ -123,27 +136,18 @@ def test_inserts_single_row(self, mock_connect, connector):
result = connector.bulk_insert("my_table", [{"id": 1, "name": "alice"}])

expected_query = "INSERT INTO my_table (id, name) VALUES (?, ?)"
mock_cursor.executemany.assert_called_once_with(
expected_query, [(1, "alice")]
)
mock_cursor.executemany.assert_called_once_with(expected_query, [(1, "alice")])
assert result is True

def test_inserts_multiple_rows(self, mock_connect, connector):
mock_cursor = MagicMock()
mock_connect.cursor.return_value.__enter__.return_value = mock_cursor

data = [
{"id": 1, "val": "a"},
{"id": 2, "val": "b"},
{"id": 3, "val": "c"},
]
data = [{"id": 1, "val": "a"}, {"id": 2, "val": "b"}, {"id": 3, "val": "c"}]
result = connector.bulk_insert("my_table", data)

expected_values = [(1, "a"), (2, "b"), (3, "c")]
_, call_values = mock_cursor.executemany.call_args
# positional args
actual_values = mock_cursor.executemany.call_args[0][1]
assert actual_values == expected_values
assert actual_values == [(1, "a"), (2, "b"), (3, "c")]
assert result is True

def test_returns_none_for_empty_list(self, mock_connect, connector, capsys):
Expand All @@ -165,10 +169,10 @@ def test_query_string_uses_correct_table_name(self, mock_connect, connector):
mock_cursor = MagicMock()
mock_connect.cursor.return_value.__enter__.return_value = mock_cursor

connector.bulk_insert("schema.target_table", [{"x": 99}])
connector.bulk_insert("hive.default.target_table", [{"x": 99}])

actual_query = mock_cursor.executemany.call_args[0][0]
assert "schema.target_table" in actual_query
assert "hive.default.target_table" in actual_query

def test_column_order_matches_first_row_keys(self, mock_connect, connector):
mock_cursor = MagicMock()
Expand All @@ -178,10 +182,19 @@ def test_column_order_matches_first_row_keys(self, mock_connect, connector):
connector.bulk_insert("t", data)

actual_query = mock_cursor.executemany.call_args[0][0]
# columns in query should match key order of first dict
for col in ["z", "a", "m"]:
assert col in actual_query

def test_placeholder_count_matches_column_count(self, mock_connect, connector):
mock_cursor = MagicMock()
mock_connect.cursor.return_value.__enter__.return_value = mock_cursor

data = [{"a": 1, "b": 2, "c": 3, "d": 4}]
connector.bulk_insert("t", data)

actual_query = mock_cursor.executemany.call_args[0][0]
assert actual_query.count("?") == 4

def test_propagates_executemany_exception(self, mock_connect, connector):
mock_cursor = MagicMock()
mock_cursor.executemany.side_effect = RuntimeError("DB write error")
Expand Down
Loading