diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d96ba51 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,66 @@ +name: CI + +on: + pull_request: + branches: + - main + +jobs: + test: + runs-on: ubuntu-22.04 + timeout-minutes: 40 + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Bun + uses: oven-sh/setup-bun@v2 + with: + bun-version: latest + + - name: Cache Bun dependencies + uses: actions/cache@v4 + with: + path: | + node_modules + ~/.bun/install/cache + key: ${{ runner.os }}-bun-${{ hashFiles('bun.lock') }} + restore-keys: | + ${{ runner.os }}-bun- + + - name: Install frontend dependencies + run: bun install --frozen-lockfile + + - name: Setup Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache Rust build artifacts + uses: Swatinem/rust-cache@v2 + with: + workspaces: src-tauri -> src-tauri/target + + - name: Run unit tests + run: bun run test:unit + + - name: Run service tests + run: bun run test:service + + - name: Run rust unit tests + run: bun run test:rust:unit + + - name: Run integration tests (MySQL + Postgres with testcontainers) + run: IT_DB=all bun run test:integration + + - name: Docker diagnostics on failure + if: failure() + run: | + echo "==== docker ps -a ====" + docker ps -a || true + echo "==== recent mysql/postgres logs ====" + for image in mysql:8.0 postgres:16-alpine; do + for id in $(docker ps -aq --filter "ancestor=${image}"); do + echo "--- logs for $id (${image}) ---" + docker logs "$id" || true + done + done diff --git a/.gitignore b/.gitignore index c9e77cf..bbf8806 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,9 @@ reference # plan .trae/documents/ .cursor/plans/ + +# skills +.trae/skills/* + +# example +githubworkflowexample/* diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index 382d0aa..0000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,86 +0,0 @@ -## [unreleased] - -### 🚀 Features - -- *(DataGrid)* Add 
right-click copy function to table view -- *(theme)* Apply a warm theme style to the dark mode editor -- *(SqlEditor)* Supports executing selected SQL fragments and optimizing save logic -## [0.1.1] - 2026-02-23 - -### 🚀 Features - -- 添加应用菜单和设置事件监听 - -### 🐛 Bug Fixes - -- *(ssh)* 验证SSH和目标端口范围并添加单元测试 - -### 📚 Documentation - -- 更新 README 以反映项目重命名为 DbPaw - -### ⚙️ Miscellaneous Tasks - -- Update version to 0.1.1 -## [0.1.0] - 2026-02-23 - -### 🚀 Features - -- 初始化 Tauri + React 桌面数据库管理应用 -- *(ui)* 更新应用图标并实现从后端获取连接列表 -- *(ui)* 引入 shadcn/ui 组件库并重构服务层 -- *(侧边栏)* 支持通过连接ID获取数据库列表和表列表 -- 增强数据库表数据浏览功能并改进类型处理 -- *(TableView)* 添加列宽调整和右键菜单功能 -- Delect-side-icon -- *(Sidebar)* 为数据库侧边栏添加数据库右键菜单 -- 实现多标签查询编辑器和数据库切换功能 -- *(ui)* 为 ResizableHandle 组件添加可选的视觉手柄 -- *(metadata)* 新增查看表结构DDL功能 -- 实现SQL编辑器受控状态和DDL缓存刷新 -- 新增主题系统与持久化设置支持 -- *(sql编辑器)* 替换 Monaco 为 CodeMirror 并添加 SQL 自动补全和格式化功能 -- *(SqlEditor)* 添加全局列和表自动补全支持 -- *(DataGrid)* Add sorting functionality to table view -- *(App)* Add TableMetadataView and DDL handling -- *(TableView)* Enhance cell editing and data refresh functionality -- *(DatabaseSidebar)* Enhance table handling and metadata fetching -- *(App, TableView, API)* Implement filtering and ordering functionality -- *(ui)* 为标签页添加右键菜单并改进占位符 -- *(数据库连接)* 为 PostgreSQL 和 MySQL 驱动程序添加 SSL 支持 -- 将界面文本从中文切换为英文 -- *(mock)* Add mock data support for independent frontend development -- 新增保存和加载查询功能 -- 添加全局快捷键和编辑器保存快捷键 -- 为保存的查询添加数据库字段支持并改进错误处理 -- *(mocks)* 添加已保存查询的模拟数据和服务 -- 支持在保存查询时编辑描述 -- 集成自定义标题栏并改进数据库迁移逻辑 -- 更新应用图标并优化侧边栏UI样式 -- *(ssh)* 支持通过 SSH 隧道连接数据库 - -### 🐛 Bug Fixes - -- 修复主题加载闪烁并改进错误处理 -- *(DataGrid)* 修复列宽调整时起始宽度计算不准确的问题 -- *(db/sql编辑器)* 修复数据库驱动列索引错误并增强调试日志 -- 移除调试日志并改进数据库模式获取的错误处理 -- Resolve unused variables in mocks.ts for CI build -- Resolve ubuntu dependency conflict - -### 🚜 Refactor - -- *(services)* 将 mock 数据中的 driver 字段统一重命名为 dbType -- 移除数据库侧边栏组件并优化标签激活逻辑 - -### 🎨 Styling - -- *(ui)* 优化侧边栏和标签页的交互样式 -- *(ui)* 为标签页触发器添加右侧边框并调整最后一个项 -- *(DataGrid)* 移除表格最小宽度限制并添加刷新图标 
-- *(ui)* 调整侧边栏及标签页的内边距和高度以优化视觉一致性 - -### ⚙️ Miscellaneous Tasks - -- 重命名项目为DbPaw并添加连接名称字段 -- 添加 GitHub Actions 发布工作流 diff --git a/README.md b/README.md index 506fe62..307737a 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ English | [简体中文](README_CN.md) | [日本語](README_JA.md) - Connect to PostgreSQL, MySQL, MariaDB (MySQL-compatible), TiDB (MySQL-compatible), SQLite, SQL Server, and ClickHouse (preview, currently read-only) - Write and run SQL with syntax highlighting, auto-completion, and one-click formatting - Browse query results in a data grid with filtering, sorting, pagination, and export +- Import `.sql` files into MySQL/MariaDB/TiDB/PostgreSQL/SQLite/DuckDB/SQL Server with all-or-nothing rollback - Save and reuse frequently used SQL scripts with Saved Queries - Use the AI sidebar to draft SQL and explain queries (optional) - Access remote databases through SSH tunneling diff --git a/README_CN.md b/README_CN.md index f7e4bc4..029bd96 100644 --- a/README_CN.md +++ b/README_CN.md @@ -36,6 +36,7 @@ - 连接 PostgreSQL、MySQL、MariaDB(MySQL 兼容)、TiDB(MySQL 兼容)、SQLite、SQL Server 与 ClickHouse(预览版,当前只读) - 编写与执行 SQL:语法高亮、自动补全、一键格式化 - 在数据网格中浏览结果,支持过滤、排序、分页与导出 +- 支持将 `.sql` 文件导入 MySQL/MariaDB/TiDB/PostgreSQL/SQLite/DuckDB/SQL Server,并在失败时全量回滚 - 使用 Saved Queries 保存并复用常用 SQL 脚本 - 使用 AI 侧边栏辅助写 SQL、解释查询(可选) - 通过 SSH 隧道访问远程数据库 diff --git a/docs/zh/Development/DEVELOPMENT.md b/docs/zh/Development/DEVELOPMENT.md index 0cdeaa0..196f1d8 100644 --- a/docs/zh/Development/DEVELOPMENT.md +++ b/docs/zh/Development/DEVELOPMENT.md @@ -51,6 +51,103 @@ bun run test:rust:unit bun run test:integration ``` +### 集成测试自动化(MySQL + MariaDB + Postgres + ClickHouse + SQL Server + DuckDB) + +- 默认执行 `bun run test:integration` 会自动启动/销毁 MySQL、MariaDB、Postgres、ClickHouse 与 SQL Server 容器(DuckDB 使用本地临时文件,不依赖容器)。 +- 可通过 `IT_DB` 指定目标数据库: + ```bash + IT_DB=mysql bun run test:integration + IT_DB=mariadb bun run test:integration + IT_DB=postgres bun run test:integration + IT_DB=clickhouse bun run 
test:integration + IT_DB=mssql bun run test:integration + IT_DB=duckdb bun run test:integration + IT_DB=all bun run test:integration + ``` +- 如需复用你本地已经启动的数据库(兼容旧流程),可设置: + ```bash + IT_REUSE_LOCAL_DB=1 bun run test:integration + ``` + +### 集成测试常见环境变量(可选覆盖) + +- MySQL: `MYSQL_HOST` `MYSQL_PORT` `MYSQL_USER` `MYSQL_PASSWORD` `MYSQL_DB` +- MariaDB: `MARIADB_HOST` `MARIADB_PORT` `MARIADB_USER` `MARIADB_PASSWORD` `MARIADB_DB` +- Postgres: `POSTGRES_HOST` `POSTGRES_PORT` `POSTGRES_USER` `POSTGRES_PASSWORD` `POSTGRES_DB` +- ClickHouse: `CLICKHOUSE_HOST` `CLICKHOUSE_PORT` `CLICKHOUSE_USER` `CLICKHOUSE_PASSWORD` `CLICKHOUSE_DB` +- SQL Server: `MSSQL_HOST` `MSSQL_PORT` `MSSQL_USER` `MSSQL_PASSWORD` `MSSQL_DB` +- DuckDB: `DUCKDB_IT_DB_PATH`(可选)`DUCKDB_DB_PATH`(可选) +- 兼容 Postgres 常见别名: `PG_HOST` `PG_PORT` `PGUSER` `PGPASSWORD` `PGDATABASE` + +### 排障建议 + +- 镜像拉取慢:先手动执行 `docker pull mysql:8.0`、`docker pull mariadb:11`、`docker pull postgres:16-alpine`、`docker pull clickhouse/clickhouse-server:24.3` 和 `docker pull mcr.microsoft.com/mssql/server:2022-latest` 预热。 +- 端口冲突:集成测试默认使用 Docker 动态映射端口,通常不会冲突;如本地复用模式冲突,请调整 `*_PORT`。 +- Apple 芯片兼容:若首次拉取较慢,建议预先拉取镜像并等待 Docker Desktop 完成架构层初始化。 + +### 推荐工作流 + +- 日常开发:优先执行 `test:unit` + `test:service`。 +- 提交前:按需执行 `test:integration` 做数据库回归。 +- PR:CI 会固定执行集成测试作为质量兜底。 + +### 功能开发后怎么跑测试(实践版) + +1. 开发过程中(高频、快速反馈) + +- 先跑: + ```bash + bun run test:unit + bun run test:service + ``` +- 适用:前端逻辑、业务逻辑、小范围改动的快速验证。 + +2. 改动涉及数据库行为时(中频) + +- 跑: + ```bash + IT_DB=all bun run test:integration + ``` +- 或按需只跑单库: + ```bash + IT_DB=mysql bun run test:integration + IT_DB=mariadb bun run test:integration + IT_DB=postgres bun run test:integration + IT_DB=clickhouse bun run test:integration + IT_DB=mssql bun run test:integration + IT_DB=duckdb bun run test:integration + ``` +- 适用:连接参数、驱动逻辑、执行 SQL、表/库元数据、DDL/DML、类型映射相关改动。 + +3. 
提交前(低频但建议) + +- 至少跑一次: + ```bash + IT_DB=all bun run test:integration + ``` +- PR 流水线会再次自动跑,作为最终兜底。 + +### 这套集成测试覆盖什么 / 不覆盖什么 + +- 能覆盖: + - Rust 数据库层真实连库能力 + - 常见数据库操作链路(连接、建表、查询、元数据、DDL) + - 驱动兼容与类型映射问题 +- 不覆盖: + - 前端 UI 的“点点点”交互流程(这属于 E2E/UI 自动化范畴) + - 纯视觉样式问题 + +### 什么时候可以不跑集成测试 + +- 仅改文案、样式、纯前端展示层,且不影响数据库交互。 +- 仅改与数据库完全无关的代码。 +- 快速迭代中间版本可不跑;合并前建议补跑一次。 + +### 容器清理说明 + +- 默认模式(未设置 `IT_REUSE_LOCAL_DB=1`)下,测试使用 testcontainers 拉起临时容器,测试结束后会自动销毁。 +- 设置 `IT_REUSE_LOCAL_DB=1` 时,测试会连接你手动准备的数据库实例,不会自动删除你自己的容器。 + ## 代码格式化 ```bash diff --git a/package.json b/package.json index f63bfee..955b1b8 100644 --- a/package.json +++ b/package.json @@ -5,7 +5,7 @@ "url": "git+https://github.com/codeErrorSleep/dbpaw.git" }, "private": true, - "version": "0.2.9", + "version": "0.3.0", "type": "module", "scripts": { "dev": "vite", diff --git a/scripts/db-onboard.sh b/scripts/db-onboard.sh new file mode 100644 index 0000000..bff50e7 --- /dev/null +++ b/scripts/db-onboard.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash +set -euo pipefail + +if [[ $# -lt 1 ]]; then + echo "Usage: scripts/db-onboard.sh [--skip-gate] [--skip-matrix]" + exit 1 +fi + +db="$1" +shift || true + +skip_gate=0 +skip_matrix=0 + +for arg in "$@"; do + case "$arg" in + --skip-gate) + skip_gate=1 + ;; + --skip-matrix) + skip_matrix=1 + ;; + *) + echo "[error] unknown option: $arg" + echo "Usage: scripts/db-onboard.sh [--skip-gate] [--skip-matrix]" + exit 1 + ;; + esac +done + +root_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${root_dir}" + +context_file="src-tauri/tests/common/${db}_context.rs" +integration_file="src-tauri/tests/${db}_integration.rs" +command_file="src-tauri/tests/${db}_command_integration.rs" +stateful_file="src-tauri/tests/${db}_stateful_command_integration.rs" +tracker_file="docs/zh/Development/MYSQL_TEST_COVERAGE_GAP_TRACKER.md" + +echo "[step] scaffold check: ${db}" +missing=0 +for file in "${context_file}" "${integration_file}" "${command_file}" "${stateful_file}"; do + if [[ ! 
-f "${file}" ]]; then + echo "[missing] ${file}" + missing=1 + else + echo "[ok] ${file}" + fi +done + +if [[ ${missing} -ne 0 ]]; then + echo "[error] scaffold is incomplete for '${db}'." + echo "[hint] finish scaffold first, then rerun scripts/db-onboard.sh ${db}" + exit 1 +fi + +if [[ ${skip_gate} -eq 0 ]]; then + echo "[step] gate syntax check" + bash -n scripts/test-integration.sh + + echo "[step] compile smoke: ${db}_integration" + cargo test --manifest-path src-tauri/Cargo.toml --test "${db}_integration" --no-run + + echo "[step] compile smoke: ${db}_command_integration" + cargo test --manifest-path src-tauri/Cargo.toml --test "${db}_command_integration" --no-run + + echo "[step] compile smoke: ${db}_stateful_command_integration" + cargo test --manifest-path src-tauri/Cargo.toml --test "${db}_stateful_command_integration" --no-run + + echo "[step] integration gate run: IT_DB=${db}" + IT_DB="${db}" bash scripts/test-integration.sh +else + echo "[skip] gate run skipped by --skip-gate" +fi + +if [[ ${skip_matrix} -eq 0 ]]; then + echo "[step] matrix sync check" + test_count="$(rg -n "async fn test_${db}_" src-tauri/tests --glob "*.rs" || true)" + test_count="$(printf "%s\n" "${test_count}" | sed '/^$/d' | wc -l | tr -d ' ')" + echo "[info] detected test functions for ${db}: ${test_count}" + if [[ -f "${tracker_file}" ]]; then + tracker_hits="$(rg -n "test_${db}_" "${tracker_file}" || true)" + tracker_hits="$(printf "%s\n" "${tracker_hits}" | sed '/^$/d' | wc -l | tr -d ' ')" + if [[ "${tracker_hits}" -eq 0 ]]; then + echo "[warn] tracker has no '${db}' test entries yet: ${tracker_file}" + echo "[next] sync capability matrix and command coverage sections for '${db}'" + else + echo "[ok] tracker already contains ${tracker_hits} '${db}' test entries" + fi + else + echo "[warn] tracker file not found: ${tracker_file}" + fi +else + echo "[skip] matrix sync check skipped by --skip-matrix" +fi + +echo "[done] db onboarding pipeline finished for '${db}'" diff --git 
a/scripts/test-integration.sh b/scripts/test-integration.sh index d5dd17f..c52aab3 100755 --- a/scripts/test-integration.sh +++ b/scripts/test-integration.sh @@ -1,26 +1,79 @@ #!/usr/bin/env bash set -euo pipefail -run_integration_test() { - local db_name="$1" - local test_name="$2" - local run_flag="$3" +it_db="${IT_DB:-all}" +it_reuse_local_db="${IT_REUSE_LOCAL_DB:-0}" +it_container_prefix="${IT_CONTAINER_PREFIX:-dbpaw-it-$$-}" +export IT_CONTAINER_PREFIX="${it_container_prefix}" - if [[ "${run_flag}" != "1" ]]; then - echo "[skip] ${db_name} integration test (set ${db_name}=1 to enable)" +cleanup_it_containers() { + if [[ "${it_reuse_local_db}" == "1" ]]; then + return 0 + fi + if ! command -v docker >/dev/null 2>&1; then return 0 fi - echo "[run] ${db_name} integration test: ${test_name}" + local ids + ids="$(docker ps -aq --filter "name=${it_container_prefix}" || true)" + if [[ -n "${ids}" ]]; then + echo "[cleanup] removing leftover integration containers: ${it_container_prefix}*" + docker rm -f ${ids} >/dev/null 2>&1 || true + fi +} + +cleanup_it_containers +trap cleanup_it_containers EXIT + +run_integration_test() { + local test_name="$1" + echo "[run] integration test: ${test_name} (IT_REUSE_LOCAL_DB=${it_reuse_local_db})" cargo test \ --manifest-path src-tauri/Cargo.toml \ - --test "${test_name}" \ - -- --ignored --nocapture + --test "${test_name}" -- --ignored --nocapture --test-threads=1 } -run_integration_test "RUN_MYSQL_IT" "mysql_integration" "${RUN_MYSQL_IT:-0}" -run_integration_test "RUN_MARIADB_IT" "mariadb_integration" "${RUN_MARIADB_IT:-0}" -run_integration_test "RUN_POSTGRES_IT" "postgres_integration" "${RUN_POSTGRES_IT:-0}" -run_integration_test "RUN_SQLITE_IT" "sqlite_integration" "${RUN_SQLITE_IT:-0}" -run_integration_test "RUN_MSSQL_IT" "mssql_integration" "${RUN_MSSQL_IT:-0}" -run_integration_test "RUN_CLICKHOUSE_IT" "clickhouse_integration" "${RUN_CLICKHOUSE_IT:-0}" +case "${it_db}" in + mysql) + run_integration_test "mysql_integration" + 
run_integration_test "mysql_command_integration" + run_integration_test "mysql_stateful_command_integration" + ;; + mariadb) + run_integration_test "mariadb_integration" + ;; + postgres) + run_integration_test "postgres_integration" + run_integration_test "postgres_command_integration" + run_integration_test "postgres_stateful_command_integration" + ;; + clickhouse) + run_integration_test "clickhouse_integration" + ;; + mssql) + run_integration_test "mssql_integration" + ;; + duckdb) + run_integration_test "duckdb_integration" + ;; + sqlite) + run_integration_test "sqlite_integration" + ;; + all) + run_integration_test "mysql_integration" + run_integration_test "mysql_command_integration" + run_integration_test "mysql_stateful_command_integration" + run_integration_test "mariadb_integration" + run_integration_test "postgres_integration" + run_integration_test "postgres_command_integration" + run_integration_test "postgres_stateful_command_integration" + run_integration_test "clickhouse_integration" + run_integration_test "mssql_integration" + run_integration_test "duckdb_integration" + run_integration_test "sqlite_integration" + ;; + *) + echo "[error] Invalid IT_DB='${it_db}'. 
Expected one of: mysql|mariadb|postgres|clickhouse|mssql|duckdb|sqlite|all" + exit 1 + ;; +esac diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 02d4dd4..4799250 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -504,6 +504,18 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +[[package]] +name = "bb8" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89aabfae550a5c44b43ab941844ffcd2e993cb6900b342debf59e9ea74acdb8" +dependencies = [ + "async-trait", + "futures-util", + "parking_lot", + "tokio", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -562,6 +574,16 @@ dependencies = [ "piper", ] +[[package]] +name = "bollard-stubs" +version = "1.42.0-rc.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed59b5c00048f48d7af971b71f800fdf23e858844a6f9e4d32ca72e9399e7864" +dependencies = [ + "serde", + "serde_with 1.14.0", +] + [[package]] name = "borsh" version = "1.6.0" @@ -1041,14 +1063,38 @@ dependencies = [ "cipher", ] +[[package]] +name = "darling" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a01d95850c592940db9b8194bc39f4bc0e89dee5c4265e4b1807c34a9aba453c" +dependencies = [ + "darling_core 0.13.4", + "darling_macro 0.13.4", +] + [[package]] name = "darling" version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.21.3", + "darling_macro 0.21.3", +] + +[[package]] +name = "darling_core" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "859d65a907b6852c9361e3185c862aae7fafd2887876799fa55f5f99dc40d610" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + 
"strsim 0.10.0", + "syn 1.0.109", ] [[package]] @@ -1061,17 +1107,28 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "strsim", + "strsim 0.11.1", "syn 2.0.117", ] +[[package]] +name = "darling_macro" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c972679f83bdf9c42bd905396b6c3588a843a17f0f16dfcfa3e2c5d57441835" +dependencies = [ + "darling_core 0.13.4", + "quote", + "syn 1.0.109", +] + [[package]] name = "darling_macro" version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" dependencies = [ - "darling_core", + "darling_core 0.21.3", "quote", "syn 2.0.117", ] @@ -1083,6 +1140,7 @@ dependencies = [ "aes-gcm", "async-trait", "base64 0.22.1", + "bb8", "chrono", "duckdb", "futures-util", @@ -1102,6 +1160,7 @@ dependencies = [ "tauri-plugin-store", "tauri-plugin-updater", "tauri-plugin-window-state", + "testcontainers", "tiberius", "tokio", "tokio-util", @@ -1183,7 +1242,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -1386,7 +1445,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -1576,6 +1635,21 @@ dependencies = [ "new_debug_unreachable", ] +[[package]] +name = "futures" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.32" @@ -1662,6 +1736,7 @@ version = "0.3.32" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -4345,7 +4420,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -4402,7 +4477,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -4674,6 +4749,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_with" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "678b5a069e50bf00ecd22d0cd8ddf7c236f68581b03db652061ed5eb13a312ff" +dependencies = [ + "serde", + "serde_with_macros 1.5.2", +] + [[package]] name = "serde_with" version = "3.17.0" @@ -4689,17 +4774,29 @@ dependencies = [ "schemars 1.2.1", "serde_core", "serde_json", - "serde_with_macros", + "serde_with_macros 3.17.0", "time", ] +[[package]] +name = "serde_with_macros" +version = "1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e182d6ec6f05393cc0e5ed1bf81ad6db3a8feedf8ee515ecdd369809bcce8082" +dependencies = [ + "darling 0.13.4", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "serde_with_macros" version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6d4e30573c8cb306ed6ab1dca8423eec9a463ea0e155f45399455e0368b27e0" dependencies = [ - "darling", + "darling 0.21.3", "proc-macro2", "quote", "syn 2.0.117", @@ -5152,6 +5249,12 @@ dependencies = [ "unicode-properties", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "strsim" version = "0.11.1" @@ -5691,7 +5794,7 @@ dependencies = [ "serde", "serde-untagged", 
"serde_json", - "serde_with", + "serde_with 3.17.0", "swift-rs", "thiserror 2.0.18", "toml 0.9.12+spec-1.1.0", @@ -5722,7 +5825,7 @@ dependencies = [ "getrandom 0.4.1", "once_cell", "rustix", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -5736,6 +5839,23 @@ dependencies = [ "utf-8", ] +[[package]] +name = "testcontainers" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d2931d7f521af5bae989f716c3fa43a6af9af7ec7a5e21b59ae40878cec00" +dependencies = [ + "bollard-stubs", + "futures", + "hex", + "hmac", + "log", + "rand 0.8.5", + "serde", + "serde_json", + "sha2", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -5877,6 +5997,7 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", "socket2", "tokio-macros", @@ -6675,7 +6796,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 1c3eb1f..6575ed1 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -33,6 +33,7 @@ tokio-util = { version = "0.7", features = ["compat"] } uuid = { version = "1", features = ["v4", "serde"] } sqlx = { version = "0.8.6", default-features = false, features = ["runtime-tokio-rustls", "postgres", "mysql", "sqlite", "macros", "chrono", "json", "rust_decimal"] } tiberius = { version = "0.12", features = ["chrono", "rust_decimal"] } +bb8 = "0.8.6" async-trait = "0.1" ssh2 = { version = "0.9", features = ["vendored-openssl"] } rand = "0.8" @@ -43,3 +44,6 @@ futures-util = "0.3" aes-gcm = "0.10" base64 = "0.22" duckdb = { version = "1.2.2", features = ["bundled"] } + +[dev-dependencies] +testcontainers = "0.15.0" diff --git a/src-tauri/src/ai/prompt.rs b/src-tauri/src/ai/prompt.rs index 0211472..92d78b5 100644 --- 
a/src-tauri/src/ai/prompt.rs +++ b/src-tauri/src/ai/prompt.rs @@ -5,8 +5,8 @@ const MAX_TABLES: usize = 8; const MAX_COLUMNS: usize = 12; const MAX_SCHEMA_CHARS: usize = 6000; -/// Build a minimal prompt bundle without restrictive rules. -/// User input is passed directly to the AI with schema context attached. +/// Build a minimal prompt bundle with system context only. +/// User and assistant turns should come from persisted conversation history. pub fn build_prompt_bundle( _scenario: &str, input: &str, @@ -15,17 +15,22 @@ pub fn build_prompt_bundle( let selected = select_tables(input, schema_overview); let schema_text = render_schema_summary(&selected); - // Simple user message with schema context attached + let mut content = + "Use the conversation history as the source of truth for user and assistant turns." + .to_string(); + let content = if schema_text.is_empty() || schema_text == "(No schema provided)" { - input.to_string() + content } else { - format!("{}\n\nDatabase schema:\n{}", input, schema_text) + content.push_str("\n\nDatabase schema:\n"); + content.push_str(&schema_text); + content }; AiPromptBundle { prompt_version: PROMPT_VERSION.to_string(), messages: vec![AiChatMessage { - role: "user".to_string(), + role: "system".to_string(), content, }], } @@ -234,8 +239,10 @@ mod tests { let bundle = build_prompt_bundle("sql_generate", "List all users", Some(&overview)); assert_eq!(bundle.messages.len(), 1); - assert_eq!(bundle.messages[0].role, "user"); - assert!(bundle.messages[0].content.contains("List all users")); + assert_eq!(bundle.messages[0].role, "system"); + assert!(bundle.messages[0] + .content + .contains("Use the conversation history as the source of truth")); assert!(bundle.messages[0].content.contains("Database schema:")); assert!(bundle.messages[0].content.contains("public.users")); } @@ -245,7 +252,10 @@ mod tests { let bundle = build_prompt_bundle("sql_generate", "Hello", None); assert_eq!(bundle.messages.len(), 1); - 
assert_eq!(bundle.messages[0].role, "user"); - assert_eq!(bundle.messages[0].content, "Hello"); + assert_eq!(bundle.messages[0].role, "system"); + assert_eq!( + bundle.messages[0].content, + "Use the conversation history as the source of truth for user and assistant turns." + ); } } diff --git a/src-tauri/src/commands/ai.rs b/src-tauri/src/commands/ai.rs index d2891f5..0830916 100644 --- a/src-tauri/src/commands/ai.rs +++ b/src-tauri/src/commands/ai.rs @@ -95,6 +95,16 @@ fn map_history_load_error(conversation_id: i64, e: &str) -> String { "Failed to load conversation history".to_string() } +fn assemble_final_messages( + bundle: &[AiChatMessage], + history: &[AiChatMessage], +) -> Vec { + let mut final_messages = Vec::with_capacity(bundle.len() + history.len()); + final_messages.extend(bundle.iter().cloned()); + final_messages.extend(history.iter().cloned()); + final_messages +} + async fn get_db(state: &State<'_, AppState>) -> Result, String> { let local_db = { let lock = state.local_db.lock().await; @@ -103,6 +113,14 @@ async fn get_db(state: &State<'_, AppState>) -> Result Result, String> { + let local_db = { + let lock = state.local_db.lock().await; + lock.clone() + }; + local_db.ok_or_else(|| "Local DB not initialized".to_string()) +} + fn provider_from_model(p: crate::models::AiProvider, api_key: String) -> OpenAICompatProvider { OpenAICompatProvider { name: p.name, @@ -139,6 +157,11 @@ pub async fn ai_list_providers( db.list_ai_providers_public().await } +pub async fn ai_list_providers_direct(state: &AppState) -> Result, String> { + let db = get_db_from_app_state(state).await?; + db.list_ai_providers_public().await +} + #[tauri::command] pub async fn ai_create_provider( state: State<'_, AppState>, @@ -150,6 +173,16 @@ pub async fn ai_create_provider( db.get_ai_provider_public_by_id(created.id).await } +pub async fn ai_create_provider_direct( + state: &AppState, + mut config: AiProviderForm, +) -> Result { + normalize_provider_form(&mut config, 
Some("openai"))?; + let db = get_db_from_app_state(state).await?; + let created = db.create_ai_provider(config).await?; + db.get_ai_provider_public_by_id(created.id).await +} + #[tauri::command] pub async fn ai_update_provider( state: State<'_, AppState>, @@ -162,18 +195,39 @@ pub async fn ai_update_provider( db.get_ai_provider_public_by_id(updated.id).await } +pub async fn ai_update_provider_direct( + state: &AppState, + id: i64, + mut config: AiProviderForm, +) -> Result { + normalize_provider_form(&mut config, None)?; + let db = get_db_from_app_state(state).await?; + let updated = db.update_ai_provider(id, config).await?; + db.get_ai_provider_public_by_id(updated.id).await +} + #[tauri::command] pub async fn ai_delete_provider(state: State<'_, AppState>, id: i64) -> Result<(), String> { let db = get_db(&state).await?; db.delete_ai_provider(id).await } +pub async fn ai_delete_provider_direct(state: &AppState, id: i64) -> Result<(), String> { + let db = get_db_from_app_state(state).await?; + db.delete_ai_provider(id).await +} + #[tauri::command] pub async fn ai_set_default_provider(state: State<'_, AppState>, id: i64) -> Result<(), String> { let db = get_db(&state).await?; db.set_default_ai_provider(id).await } +pub async fn ai_set_default_provider_direct(state: &AppState, id: i64) -> Result<(), String> { + let db = get_db_from_app_state(state).await?; + db.set_default_ai_provider(id).await +} + #[tauri::command] pub async fn ai_clear_provider_api_key( state: State<'_, AppState>, @@ -184,6 +238,15 @@ pub async fn ai_clear_provider_api_key( db.clear_ai_provider_api_key(&provider_type).await } +pub async fn ai_clear_provider_api_key_direct( + state: &AppState, + provider_type: String, +) -> Result<(), String> { + let provider_type = normalize_provider_type(&provider_type)?; + let db = get_db_from_app_state(state).await?; + db.clear_ai_provider_api_key(&provider_type).await +} + #[tauri::command] pub async fn ai_chat_start( app: tauri::AppHandle, @@ -372,8 +435,7 @@ 
async fn run_chat( } } - let mut final_messages = bundle.messages.clone(); - final_messages.extend(history); + let final_messages = assemble_final_messages(&bundle.messages, &history); let _ = app.emit( "ai/started", @@ -440,6 +502,178 @@ async fn run_chat( }) } +async fn run_chat_direct( + state: &AppState, + request: AiChatRequest, + create_if_missing: bool, +) -> Result { + let db = get_db_from_app_state(state).await?; + + let provider_record = if let Some(provider_id) = request.provider_id { + db.get_ai_provider_by_id(provider_id) + .await + .map_err(|e| map_provider_lookup_error(&e))? + } else { + db.get_default_ai_provider() + .await + .map_err(|e| map_default_provider_error(&e))? + }; + + ensure_provider_enabled(provider_record.enabled)?; + validate_conversation_requirement(request.conversation_id, create_if_missing)?; + + let api_key = db + .decrypt_ai_api_key(&provider_record.api_key) + .map_err(|_| { + "AI provider apiKey is missing or invalid. Please re-save it in AI Provider settings." + .to_string() + })?; + let provider = provider_from_model(provider_record.clone(), api_key); + provider.validate_config()?; + + let conversation = match request.conversation_id { + Some(id) => db.get_ai_conversation(id).await?, + None if create_if_missing => { + let title = request + .title + .clone() + .unwrap_or_else(|| request.input.chars().take(36).collect()); + db.create_ai_conversation( + title, + request.scenario.clone(), + request.connection_id, + request.database.clone(), + ) + .await? 
+ } + None => unreachable!("conversation requirement should be validated before this branch"), + }; + + let user_message = db + .create_ai_message( + conversation.id, + "user".to_string(), + request.input.clone(), + None, + None, + None, + None, + None, + ) + .await?; + + let mut schema_override: Option = None; + let mut selection_hint = String::new(); + if let (Some(conn_id), Some(selected)) = + (request.connection_id, request.selected_tables.as_ref()) + { + if !selected.is_empty() { + let driver = super::ensure_connection_with_db_from_app_state( + state, + conn_id, + request.database.clone(), + ) + .await?; + let mut tables: Vec = Vec::new(); + for t in selected { + let structure = driver + .get_table_structure(t.schema.clone(), t.name.clone()) + .await?; + let columns = structure + .columns + .into_iter() + .map(|c| AiColumnSummary { + name: c.name, + column_type: c.r#type, + nullable: Some(c.nullable), + }) + .collect(); + tables.push(AiTableSummary { + schema: t.schema.clone(), + name: t.name.clone(), + columns, + }); + } + selection_hint = selected + .iter() + .map(|t| t.name.as_str()) + .collect::>() + .join(" "); + schema_override = Some(AiSchemaOverview { tables }); + } + } + + let input_for_prompt = if selection_hint.is_empty() { + request.input.clone() + } else { + format!("{} {}", request.input, selection_hint) + }; + + let bundle = build_prompt_bundle( + &request.scenario, + &input_for_prompt, + schema_override + .as_ref() + .or_else(|| request.schema_overview.as_ref()), + ); + + let mut history: Vec = Vec::new(); + let mut existing = db + .list_ai_messages(conversation.id) + .await + .map_err(|e| map_history_load_error(conversation.id, &e))?; + if existing.len() > 16 { + existing = existing.split_off(existing.len() - 16); + } + for item in existing { + if item.role == "user" || item.role == "assistant" { + history.push(AiChatMessage { + role: item.role, + content: item.content, + }); + } + } + + let final_messages = 
assemble_final_messages(&bundle.messages, &history); + let start = std::time::Instant::now(); + let response = provider.chat_stream(final_messages, |_piece| {}).await?; + let latency_ms = start.elapsed().as_millis() as i64; + + let assistant_message = db + .create_ai_message( + conversation.id, + "assistant".to_string(), + response.content.clone(), + Some(bundle.prompt_version), + Some(response.model.clone()), + response.usage.as_ref().and_then(|u| u.prompt_tokens), + response.usage.as_ref().and_then(|u| u.completion_tokens), + Some(latency_ms), + ) + .await?; + let _ = db.touch_ai_conversation(conversation.id).await; + + Ok(AiStartResponse { + conversation_id: conversation.id, + user_message_id: user_message.id, + assistant_message_id: assistant_message.id, + }) +} + +pub async fn ai_chat_start_direct( + state: &AppState, + request: AiChatRequest, +) -> Result { + run_chat_direct(state, request, true).await +} + +pub async fn ai_chat_continue_direct( + state: &AppState, + request: AiChatRequest, +) -> Result { + run_chat_direct(state, request, false).await +} + #[tauri::command] pub async fn ai_list_conversations( state: State<'_, AppState>, @@ -450,6 +684,15 @@ pub async fn ai_list_conversations( db.list_ai_conversations(connection_id, database).await } +pub async fn ai_list_conversations_direct( + state: &AppState, + connection_id: Option, + database: Option, +) -> Result, String> { + let db = get_db_from_app_state(state).await?; + db.list_ai_conversations(connection_id, database).await +} + #[tauri::command] pub async fn ai_get_conversation( state: State<'_, AppState>, @@ -464,6 +707,19 @@ pub async fn ai_get_conversation( }) } +pub async fn ai_get_conversation_direct( + state: &AppState, + conversation_id: i64, +) -> Result { + let db = get_db_from_app_state(state).await?; + let conversation = db.get_ai_conversation(conversation_id).await?; + let messages = db.list_ai_messages(conversation_id).await?; + Ok(AiConversationDetail { + conversation, + messages, + 
}) +} + #[tauri::command] pub async fn ai_delete_conversation( state: State<'_, AppState>, @@ -473,12 +729,22 @@ pub async fn ai_delete_conversation( db.delete_ai_conversation(conversation_id).await } +pub async fn ai_delete_conversation_direct( + state: &AppState, + conversation_id: i64, +) -> Result<(), String> { + let db = get_db_from_app_state(state).await?; + db.delete_ai_conversation(conversation_id).await +} + #[cfg(test)] mod tests { use super::{ - ensure_provider_enabled, map_default_provider_error, map_history_load_error, - map_provider_lookup_error, normalize_provider_type, validate_conversation_requirement, + assemble_final_messages, ensure_provider_enabled, map_default_provider_error, + map_history_load_error, map_provider_lookup_error, normalize_provider_type, + validate_conversation_requirement, }; + use crate::ai::types::AiChatMessage; #[test] fn normalize_provider_type_rejects_empty_value() { @@ -551,4 +817,34 @@ mod tests { "Failed to load conversation history" ); } + + #[test] + fn assemble_final_messages_keeps_context_before_history() { + let bundle = vec![AiChatMessage { + role: "system".to_string(), + content: "schema".to_string(), + }]; + let history = vec![ + AiChatMessage { + role: "user".to_string(), + content: "older question".to_string(), + }, + AiChatMessage { + role: "assistant".to_string(), + content: "older answer".to_string(), + }, + AiChatMessage { + role: "user".to_string(), + content: "latest question".to_string(), + }, + ]; + + let final_messages = assemble_final_messages(&bundle, &history); + + assert_eq!(final_messages.len(), 4); + assert_eq!(final_messages[0].role, "system"); + assert_eq!(final_messages[1].content, "older question"); + assert_eq!(final_messages[2].content, "older answer"); + assert_eq!(final_messages[3].content, "latest question"); + } } diff --git a/src-tauri/src/commands/connection.rs b/src-tauri/src/commands/connection.rs index 91de5ba..452f34e 100644 --- a/src-tauri/src/commands/connection.rs +++ 
b/src-tauri/src/commands/connection.rs @@ -58,6 +58,10 @@ fn quote_mysql_ident(ident: &str) -> String { format!("`{}`", ident.replace('`', "``")) } +fn quote_clickhouse_ident(ident: &str) -> String { + format!("`{}`", ident.replace('`', "``")) +} + fn quote_pg_ident(ident: &str) -> String { format!("\"{}\"", ident.replace('"', "\"\"")) } @@ -144,6 +148,49 @@ fn build_mssql_create_database_sql( Ok(create_sql) } +fn build_clickhouse_create_database_sql( + payload: &CreateDatabasePayload, + db_name: &str, +) -> Result { + if let Some(v) = normalize_option_token(&payload.charset, "charset")? { + return Err(format!( + "[VALIDATION_ERROR] ClickHouse create database does not support charset option: {}", + v + )); + } + if let Some(v) = normalize_option_token(&payload.collation, "collation")? { + return Err(format!( + "[VALIDATION_ERROR] ClickHouse create database does not support collation option: {}", + v + )); + } + if let Some(v) = normalize_option_token(&payload.encoding, "encoding")? { + return Err(format!( + "[VALIDATION_ERROR] ClickHouse create database does not support encoding option: {}", + v + )); + } + if let Some(v) = normalize_option_token(&payload.lc_collate, "lc_collate")? { + return Err(format!( + "[VALIDATION_ERROR] ClickHouse create database does not support lc_collate option: {}", + v + )); + } + if let Some(v) = normalize_option_token(&payload.lc_ctype, "lc_ctype")? 
{ + return Err(format!( + "[VALIDATION_ERROR] ClickHouse create database does not support lc_ctype option: {}", + v + )); + } + + let mut sql = String::from("CREATE DATABASE "); + if payload.if_not_exists.unwrap_or(true) { + sql.push_str("IF NOT EXISTS "); + } + sql.push_str(&quote_clickhouse_ident(db_name)); + Ok(sql) +} + fn normalize_create_database_error(err: String, db_name: &str) -> String { let lower = err.to_lowercase(); if lower.contains("already exists") @@ -185,6 +232,13 @@ pub async fn list_databases_by_id( .await } +pub async fn list_databases_by_id_direct(state: &AppState, id: i64) -> Result, String> { + super::execute_with_retry_from_app_state(state, id, None, |driver| async move { + driver.list_databases().await + }) + .await +} + #[tauri::command] pub async fn create_database_by_id( state: State<'_, AppState>, @@ -205,7 +259,7 @@ pub async fn create_database_by_id( .to_lowercase() }; - if matches!(driver.as_str(), "sqlite" | "duckdb" | "clickhouse") { + if matches!(driver.as_str(), "sqlite" | "duckdb") { return Err(format!( "[UNSUPPORTED] Driver {} does not support creating databases in this flow", driver @@ -250,6 +304,95 @@ pub async fn create_database_by_id( }) .await } + "clickhouse" => { + let sql = build_clickhouse_create_database_sql(&payload, &db_name)?; + super::execute_with_retry(&state, id, None, |driver| { + let sql_clone = sql.clone(); + async move { driver.execute_query(sql_clone).await.map(|_| ()) } + }) + .await + } + _ => Err(format!( + "[UNSUPPORTED] Driver {} not supported for create database", + driver + )), + }; + + exec_res.map_err(|e| normalize_create_database_error(e, &db_name)) +} + +pub async fn create_database_by_id_direct( + state: &AppState, + id: i64, + payload: CreateDatabasePayload, +) -> Result<(), String> { + let db_name = validate_database_name(&payload.name)?; + let if_not_exists = payload.if_not_exists.unwrap_or(true); + let driver = { + let local_db = { + let lock = state.local_db.lock().await; + lock.clone() + };
+ let db = local_db.ok_or("Local DB not initialized".to_string())?; + db.get_connection_form_by_id(id) + .await? + .driver + .to_lowercase() + }; + + if matches!(driver.as_str(), "sqlite" | "duckdb") { + return Err(format!( + "[UNSUPPORTED] Driver {} does not support creating databases in this flow", + driver + )); + } + + let exec_res = match driver.as_str() { + "mysql" | "mariadb" | "tidb" => { + let sql = build_mysql_create_database_sql(&payload, &db_name)?; + super::execute_with_retry_from_app_state(state, id, None, |driver| { + let sql_clone = sql.clone(); + async move { driver.execute_query(sql_clone).await.map(|_| ()) } + }) + .await + } + "postgres" => { + let create_sql = build_postgres_create_database_sql(&payload, &db_name)?; + let exists_check_sql = format!( + "SELECT 1 FROM pg_database WHERE datname = {} LIMIT 1", + quote_literal(&db_name) + ); + super::execute_with_retry_from_app_state(state, id, None, |driver| { + let exists_sql = exists_check_sql.clone(); + let create_sql = create_sql.clone(); + async move { + if if_not_exists { + let exists_result = driver.execute_query(exists_sql).await?; + if exists_result.row_count > 0 || !exists_result.data.is_empty() { + return Ok(()); + } + } + driver.execute_query(create_sql).await.map(|_| ()) + } + }) + .await + } + "mssql" => { + let sql = build_mssql_create_database_sql(&payload, &db_name)?; + super::execute_with_retry_from_app_state(state, id, None, |driver| { + let sql_clone = sql.clone(); + async move { driver.execute_query(sql_clone).await.map(|_| ()) } + }) + .await + } + "clickhouse" => { + let sql = build_clickhouse_create_database_sql(&payload, &db_name)?; + super::execute_with_retry_from_app_state(state, id, None, |driver| { + let sql_clone = sql.clone(); + async move { driver.execute_query(sql_clone).await.map(|_| ()) } + }) + .await + } _ => Err(format!( "[UNSUPPORTED] Driver {} not supported for create database", driver @@ -289,6 +432,18 @@ pub async fn get_connections(state: State<'_, 
AppState>) -> Result Result, String> { + let local_db = { + let lock = state.local_db.lock().await; + lock.clone() + }; + if let Some(db) = local_db { + db.list_connections().await + } else { + Err("Local DB not initialized".to_string()) + } +} + #[tauri::command] pub async fn create_connection( state: State<'_, AppState>, @@ -306,6 +461,22 @@ pub async fn create_connection( } } +pub async fn create_connection_direct( + state: &AppState, + form: ConnectionForm, +) -> Result { + let form = crate::connection_input::normalize_connection_form(form)?; + let local_db = { + let lock = state.local_db.lock().await; + lock.clone() + }; + if let Some(db) = local_db { + db.create_connection(form).await + } else { + Err("Local DB not initialized".to_string()) + } +} + #[tauri::command] pub async fn update_connection( state: State<'_, AppState>, @@ -327,6 +498,24 @@ pub async fn update_connection( } } +pub async fn update_connection_direct( + state: &AppState, + id: i64, + form: ConnectionForm, +) -> Result { + let form = crate::connection_input::normalize_connection_form(form)?; + let local_db = { + let lock = state.local_db.lock().await; + lock.clone() + }; + if let Some(db) = local_db { + state.pool_manager.remove_by_prefix(&id.to_string()).await; + db.update_connection(id, form).await + } else { + Err("Local DB not initialized".to_string()) + } +} + #[tauri::command] pub async fn delete_connection(state: State<'_, AppState>, id: i64) -> Result<(), String> { let local_db = { @@ -343,16 +532,32 @@ pub async fn delete_connection(state: State<'_, AppState>, id: i64) -> Result<() } } +pub async fn delete_connection_direct(state: &AppState, id: i64) -> Result<(), String> { + let local_db = { + let lock = state.local_db.lock().await; + lock.clone() + }; + if let Some(db) = local_db { + state.pool_manager.remove_by_prefix(&id.to_string()).await; + db.delete_connection(id).await + } else { + Err("Local DB not initialized".to_string()) + } +} + #[cfg(test)] mod tests { use super::{ - 
build_mssql_create_database_sql, build_mysql_create_database_sql, - build_postgres_create_database_sql, validate_database_name, CreateDatabasePayload, + build_clickhouse_create_database_sql, build_mssql_create_database_sql, + build_mysql_create_database_sql, build_postgres_create_database_sql, + validate_database_name, CreateDatabasePayload, }; use super::{ - normalize_create_database_error, normalize_option_token, quote_mssql_ident, - quote_mysql_ident, quote_pg_ident, + normalize_create_database_error, normalize_option_token, quote_clickhouse_ident, + quote_mssql_ident, quote_mysql_ident, quote_pg_ident, }; + use crate::connection_input::normalize_connection_form; + use crate::models::ConnectionForm; #[test] fn validate_database_name_rejects_empty_and_null() { @@ -393,10 +598,8 @@ mod tests { ); assert!(already.contains("[ALREADY_EXISTS]")); - let postgres = normalize_create_database_error( - "ERROR: 42P04 duplicate_database".to_string(), - "app", - ); + let postgres = + normalize_create_database_error("ERROR: 42P04 duplicate_database".to_string(), "app"); assert!(postgres.contains("[ALREADY_EXISTS]")); let perm = normalize_create_database_error( @@ -406,9 +609,29 @@ mod tests { assert!(perm.contains("[PERMISSION_DENIED]")); } + #[test] + fn mysql_ephemeral_flow_preserves_empty_password_through_normalization() { + let form = ConnectionForm { + driver: "mysql".to_string(), + host: Some(" localhost ".to_string()), + port: Some(3306), + username: Some(" root ".to_string()), + password: Some(" ".to_string()), + database: Some(" app ".to_string()), + ..Default::default() + }; + + let normalized = normalize_connection_form(form).unwrap(); + let dsn = crate::db::drivers::mysql::build_test_dsn(&normalized).unwrap(); + + assert_eq!(normalized.password, Some(String::new())); + assert_eq!(dsn, "mysql://root:@localhost:3306/app?ssl-mode=DISABLED"); + } + #[test] fn quote_idents_escape_driver_specific_characters() { assert_eq!(quote_mysql_ident("a`b"), "`a``b`"); + 
assert_eq!(quote_clickhouse_ident("a`b"), "`a``b`"); assert_eq!(quote_pg_ident("a\"b"), "\"a\"\"b\""); assert_eq!(quote_mssql_ident("a]b"), "[a]]b]"); } @@ -475,4 +698,40 @@ mod tests { "IF DB_ID(N'foo') IS NULL CREATE DATABASE [foo] COLLATE SQL_Latin1_General_CP1_CI_AS" ); } + + #[test] + fn clickhouse_sql_respects_if_not_exists() { + let sql = build_clickhouse_create_database_sql( + &CreateDatabasePayload { + name: "analytics".to_string(), + if_not_exists: Some(true), + charset: None, + collation: None, + encoding: None, + lc_collate: None, + lc_ctype: None, + }, + "analytics", + ) + .unwrap(); + assert_eq!(sql, "CREATE DATABASE IF NOT EXISTS `analytics`"); + } + + #[test] + fn clickhouse_sql_rejects_unsupported_options() { + let err = build_clickhouse_create_database_sql( + &CreateDatabasePayload { + name: "analytics".to_string(), + if_not_exists: Some(true), + charset: Some("utf8mb4".to_string()), + collation: None, + encoding: None, + lc_collate: None, + lc_ctype: None, + }, + "analytics", + ) + .unwrap_err(); + assert!(err.contains("does not support charset option")); + } } diff --git a/src-tauri/src/commands/metadata.rs b/src-tauri/src/commands/metadata.rs index a171959..f0e7c0f 100644 --- a/src-tauri/src/commands/metadata.rs +++ b/src-tauri/src/commands/metadata.rs @@ -2,6 +2,16 @@ use crate::models::{ConnectionForm, SchemaOverview, TableInfo, TableMetadata, Ta use crate::state::AppState; use tauri::State; +fn ensure_table_structure_found(structure: TableStructure, table: &str) -> Result { + if structure.columns.is_empty() { + return Err(format!( + "[NOT_FOUND] Table '{}' does not exist or has no visible columns", + table + )); + } + Ok(structure) +} + #[tauri::command] pub async fn get_schema_overview( state: State<'_, AppState>, @@ -16,6 +26,19 @@ pub async fn get_schema_overview( .await } +pub async fn get_schema_overview_direct( + state: &AppState, + id: i64, + database: Option, + schema: Option, +) -> Result { + 
super::execute_with_retry_from_app_state(state, id, database, |driver| { + let schema_clone = schema.clone(); + async move { driver.get_schema_overview(schema_clone).await } + }) + .await +} + #[tauri::command] pub async fn list_tables_by_conn(form: ConnectionForm) -> Result, String> { let driver = crate::db::drivers::connect(&form).await?; @@ -46,12 +69,30 @@ pub async fn get_table_structure( schema: String, table: String, ) -> Result { + let table_name = table.clone(); super::execute_with_retry(&state, id, None, |driver| { let schema_clone = schema.clone(); let table_clone = table.clone(); async move { driver.get_table_structure(schema_clone, table_clone).await } }) .await + .and_then(|structure| ensure_table_structure_found(structure, &table_name)) +} + +pub async fn get_table_structure_direct( + state: &AppState, + id: i64, + schema: String, + table: String, +) -> Result { + let table_name = table.clone(); + super::execute_with_retry_from_app_state(state, id, None, |driver| { + let schema_clone = schema.clone(); + let table_clone = table.clone(); + async move { driver.get_table_structure(schema_clone, table_clone).await } + }) + .await + .and_then(|structure| ensure_table_structure_found(structure, &table_name)) } #[tauri::command] @@ -70,6 +111,21 @@ pub async fn get_table_ddl( .await } +pub async fn get_table_ddl_direct( + state: &AppState, + id: i64, + database: Option, + schema: String, + table: String, +) -> Result { + super::execute_with_retry_from_app_state(state, id, database, |driver| { + let schema_clone = schema.clone(); + let table_clone = table.clone(); + async move { driver.get_table_ddl(schema_clone, table_clone).await } + }) + .await +} + #[tauri::command] pub async fn get_table_metadata( state: State<'_, AppState>, @@ -85,3 +141,18 @@ pub async fn get_table_metadata( }) .await } + +pub async fn get_table_metadata_direct( + state: &AppState, + id: i64, + database: Option, + schema: String, + table: String, +) -> Result { + 
super::execute_with_retry_from_app_state(state, id, database, |driver| { + let schema_clone = schema.clone(); + let table_clone = table.clone(); + async move { driver.get_table_metadata(schema_clone, table_clone).await } + }) + .await +} diff --git a/src-tauri/src/commands/mod.rs b/src-tauri/src/commands/mod.rs index 69fb9a9..07f08f7 100644 --- a/src-tauri/src/commands/mod.rs +++ b/src-tauri/src/commands/mod.rs @@ -67,6 +67,45 @@ pub async fn ensure_connection_with_db( state.pool_manager.connect(&key, &form).await } +pub async fn ensure_connection_with_db_from_app_state( + state: &AppState, + id: i64, + database: Option, +) -> Result, String> { + let key = connection_pool_key(id, &database); + + if let Some(driver) = state.pool_manager.get_connection(&key).await { + let local_db = { + let lock = state.local_db.lock().await; + lock.clone() + }; + + if let Some(db) = local_db { + if db.get_connection_by_id(id).await.is_err() { + state.pool_manager.remove_by_prefix(&id.to_string()).await; + return Err(format!("Connection with ID {} no longer exists", id)); + } + } + return Ok(driver); + } + + let local_db = { + let lock = state.local_db.lock().await; + lock.clone() + }; + + let db = local_db.ok_or("Local DB not initialized")?; + let mut form = db.get_connection_form_by_id(id).await?; + + if let Some(db_name) = database { + if !db_name.is_empty() { + form.database = Some(db_name); + } + } + + state.pool_manager.connect(&key, &form).await +} + async fn execute_with_retry_core( mut ensure: Ensure, mut remove: Remove, @@ -119,6 +158,25 @@ where .await } +pub async fn execute_with_retry_from_app_state( + state: &AppState, + id: i64, + database: Option, + task: F, +) -> Result +where + F: Fn(Arc) -> Fut, + Fut: std::future::Future>, +{ + let key = connection_pool_key(id, &database); + execute_with_retry_core( + || ensure_connection_with_db_from_app_state(state, id, database.clone()), + || state.pool_manager.remove(&key), + task, + ) + .await +} + fn is_connection_error(e: 
&str) -> bool { let lower = e.to_lowercase(); lower.contains("pool closed") diff --git a/src-tauri/src/commands/query.rs b/src-tauri/src/commands/query.rs index d103541..196a61f 100644 --- a/src-tauri/src/commands/query.rs +++ b/src-tauri/src/commands/query.rs @@ -593,6 +593,40 @@ async fn append_sql_execution_log( } } +async fn append_sql_execution_log_direct( + state: &AppState, + sql: String, + source: Option, + connection_id: Option, + database: Option, + success: bool, + error: Option, +) { + let db = { + let lock = state.local_db.lock().await; + lock.clone() + }; + + if let Some(local_db) = db { + if let Err(e) = local_db + .insert_sql_execution_log(sql, source, connection_id, database, success, error) + .await + { + eprintln!("[SQL_LOG_APPEND_ERROR] {}", e); + } + } +} + +fn validate_page_limit(page: i64, limit: i64) -> Result<(), String> { + if page <= 0 { + return Err("[VALIDATION_ERROR] page must be greater than 0".to_string()); + } + if limit <= 0 { + return Err("[VALIDATION_ERROR] limit must be greater than 0".to_string()); + } + Ok(()) +} + #[tauri::command] pub async fn get_table_data_by_conn( form: ConnectionForm, @@ -601,6 +635,7 @@ pub async fn get_table_data_by_conn( page: i64, limit: i64, ) -> Result { + validate_page_limit(page, limit)?; let driver = crate::db::drivers::connect(&form).await?; driver .get_table_data(schema, table, page, limit, None, None, None, None) @@ -691,6 +726,93 @@ pub async fn execute_query( result } +async fn resolve_driver_from_app_state(state: &AppState, id: i64) -> Option { + let db = { + let lock = state.local_db.lock().await; + lock.clone() + }?; + db.get_connection_form_by_id(id) + .await + .ok() + .map(|f| f.driver) +} + +pub async fn execute_query_by_id_direct( + state: &AppState, + id: i64, + query: String, + database: Option, + source: Option, + query_id: Option, +) -> Result { + let query_id = make_query_id(id, query_id); + let driver = resolve_driver_from_app_state(state, id).await; + let is_clickhouse = 
driver + .as_deref() + .map(|d| d.eq_ignore_ascii_case("clickhouse")) + .unwrap_or(false); + let guarded_query = maybe_apply_default_limit(&query, driver.as_deref()); + if is_clickhouse { + register_running_query(id, &query_id).await; + } + + let result = super::execute_with_retry_from_app_state(state, id, database.clone(), |driver| { + let query_clone = guarded_query.clone(); + let query_id_clone = query_id.clone(); + async move { + driver + .execute_query_with_id( + query_clone, + if is_clickhouse { + Some(query_id_clone.as_str()) + } else { + None + }, + ) + .await + } + }) + .await; + if is_clickhouse { + unregister_running_query(id, &query_id).await; + } + + if result.is_ok() { + append_sql_execution_log_direct( + state, + guarded_query.clone(), + source, + Some(id), + database, + true, + None, + ) + .await; + } else if let Err(err) = &result { + append_sql_execution_log_direct( + state, + guarded_query.clone(), + source, + Some(id), + database, + false, + Some(err.clone()), + ) + .await; + } + + result +} + +pub async fn execute_by_conn_direct( + form: ConnectionForm, + sql: String, +) -> Result { + let guarded_sql = maybe_apply_default_limit(&sql, Some(&form.driver)); + let driver = crate::db::drivers::connect(&form).await?; + driver.execute_query_with_id(guarded_sql, None).await +} + #[tauri::command] pub async fn get_table_data( state: State<'_, AppState>, @@ -705,6 +827,7 @@ pub async fn get_table_data( sort_direction: Option, order_by: Option, ) -> Result { + validate_page_limit(page, limit)?; super::execute_with_retry(&state, id, database, |driver| { let schema_clone = schema.clone(); let table_clone = table.clone(); @@ -849,6 +972,56 @@ pub async fn list_sql_execution_logs( } } +pub async fn list_sql_execution_logs_direct( + state: &AppState, + limit: Option, +) -> Result, String> { + let safe_limit = clamp_sql_execution_logs_limit(limit); + let local_db = { + let lock = state.local_db.lock().await; + lock.clone() + }; + + if let Some(db) = local_db { 
+ db.list_sql_execution_logs(safe_limit).await + } else { + Err("Local DB not initialized".to_string()) + } +} + +pub async fn cancel_query_direct( + state: &AppState, + uuid: String, + query_id: String, +) -> Result { + let connection_id = uuid + .trim() + .parse::() + .map_err(|_| "[VALIDATION_ERROR] Invalid connection id for cancellation".to_string())?; + let query_id = query_id.trim().to_string(); + if query_id.is_empty() { + return Err("[VALIDATION_ERROR] query_id cannot be empty".to_string()); + } + if !is_running_query(connection_id, &query_id).await { + return Ok(false); + } + + let local_db = { + let lock = state.local_db.lock().await; + lock.clone() + }; + let db = local_db.ok_or("Local DB not initialized".to_string())?; + let form = db.get_connection_form_by_id(connection_id).await?; + if !form.driver.eq_ignore_ascii_case("clickhouse") { + return Ok(false); + } + + let driver = crate::db::drivers::clickhouse::ClickHouseDriver::connect(&form).await?; + driver.kill_query(&query_id).await?; + unregister_running_query(connection_id, &query_id).await; + Ok(true) +} + #[cfg(test)] mod tests { use super::{ @@ -1054,9 +1227,8 @@ mod tests { #[test] fn collect_top_level_keywords_skips_subqueries_and_strings() { - let tokens = collect_top_level_keywords( - "WITH cte AS (SELECT 'from' AS v) SELECT * FROM cte", - ); + let tokens = + collect_top_level_keywords("WITH cte AS (SELECT 'from' AS v) SELECT * FROM cte"); assert_eq!(tokens.first().map(String::as_str), Some("with")); assert!(tokens.contains(&"select".to_string())); assert!(tokens.contains(&"from".to_string())); diff --git a/src-tauri/src/commands/storage.rs b/src-tauri/src/commands/storage.rs index 7824e6e..39ec88f 100644 --- a/src-tauri/src/commands/storage.rs +++ b/src-tauri/src/commands/storage.rs @@ -20,6 +20,23 @@ pub async fn save_query( } } +pub async fn save_query_direct( + state: &AppState, + name: String, + query: String, + description: Option, + connection_id: Option, + database: Option, +) -> 
Result { + let local_db = state.local_db.lock().await; + if let Some(db) = local_db.as_ref() { + db.create_saved_query(name, query, description, connection_id, database) + .await + } else { + Err("Local DB not initialized".to_string()) + } +} + #[tauri::command] pub async fn update_saved_query( state: State<'_, AppState>, @@ -39,6 +56,24 @@ pub async fn update_saved_query( } } +pub async fn update_saved_query_direct( + state: &AppState, + id: i64, + name: String, + query: String, + description: Option, + connection_id: Option, + database: Option, +) -> Result { + let local_db = state.local_db.lock().await; + if let Some(db) = local_db.as_ref() { + db.update_saved_query(id, name, query, description, connection_id, database) + .await + } else { + Err("Local DB not initialized".to_string()) + } +} + #[tauri::command] pub async fn delete_saved_query(state: State<'_, AppState>, id: i64) -> Result<(), String> { let local_db = state.local_db.lock().await; @@ -49,6 +84,15 @@ pub async fn delete_saved_query(state: State<'_, AppState>, id: i64) -> Result<( } } +pub async fn delete_saved_query_direct(state: &AppState, id: i64) -> Result<(), String> { + let local_db = state.local_db.lock().await; + if let Some(db) = local_db.as_ref() { + db.delete_saved_query(id).await + } else { + Err("Local DB not initialized".to_string()) + } +} + #[tauri::command] pub async fn get_saved_queries(state: State<'_, AppState>) -> Result, String> { let local_db = state.local_db.lock().await; @@ -58,3 +102,12 @@ pub async fn get_saved_queries(state: State<'_, AppState>) -> Result Result, String> { + let local_db = state.local_db.lock().await; + if let Some(db) = local_db.as_ref() { + db.list_saved_queries().await + } else { + Err("Local DB not initialized".to_string()) + } +} diff --git a/src-tauri/src/commands/transfer.rs b/src-tauri/src/commands/transfer.rs index 067152b..94bd044 100644 --- a/src-tauri/src/commands/transfer.rs +++ b/src-tauri/src/commands/transfer.rs @@ -3,10 +3,12 @@ use 
serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use std::fs::{self, File}; use std::io::{BufWriter, Write}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use tauri::State; const DEFAULT_CHUNK_SIZE: i64 = 2000; +const MAX_IMPORT_FILE_SIZE_BYTES: u64 = 20 * 1024 * 1024; +const MAX_IMPORT_STATEMENTS: usize = 50_000; #[derive(Debug, Clone, Deserialize)] #[serde(rename_all = "snake_case")] @@ -31,6 +33,33 @@ pub struct ExportResult { pub row_count: i64, } +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ImportSqlResult { + pub file_path: String, + pub total_statements: i64, + pub success_statements: i64, + pub failed_at: Option, + pub failed_batch: Option, + pub failed_statement_preview: Option, + pub error: Option, + pub time_taken_ms: i64, + pub rolled_back: bool, +} + +#[derive(Debug, Clone)] +struct ImportExecutionUnit { + sql: String, + batch_index: usize, + preview: String, +} + +#[derive(Debug, Clone)] +struct PreparedImportPlan { + units: Vec, + script_managed_transaction: bool, +} + #[tauri::command] pub async fn export_table_data( state: State<'_, AppState>, @@ -162,6 +191,135 @@ pub async fn export_table_data( .await } +pub async fn export_table_data_direct( + state: &AppState, + id: i64, + database: Option, + schema: String, + table: String, + driver: String, + format: ExportFormat, + scope: ExportScope, + filter: Option, + order_by: Option, + sort_column: Option, + sort_direction: Option, + page: Option, + limit: Option, + file_path: Option, + chunk_size: Option, +) -> Result { + let output_path = resolve_output_path(file_path, &table, extension_for_format(&format))?; + let chunk = chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE).max(1); + + super::execute_with_retry_from_app_state(state, id, database, |db_driver| { + let output_path = output_path.clone(); + let schema = schema.clone(); + let table = table.clone(); + let driver = driver.clone(); + let filter = filter.clone(); + let order_by = 
order_by.clone(); + let sort_column = sort_column.clone(); + let sort_direction = sort_direction.clone(); + let scope = scope.clone(); + let format = format.clone(); + async move { + let columns = db_driver + .get_table_metadata(schema.clone(), table.clone()) + .await? + .columns + .into_iter() + .map(|c| c.name) + .collect::>(); + + let mut writer = + ExportWriter::new(output_path.clone(), format.clone(), columns.clone())?; + let mut exported = 0i64; + + match scope { + ExportScope::CurrentPage => { + let use_page = page.unwrap_or(1).max(1); + let use_limit = limit.unwrap_or(50).max(1); + let resp = db_driver + .get_table_data_chunk( + schema.clone(), + table.clone(), + use_page, + use_limit, + sort_column.clone(), + sort_direction.clone(), + filter.clone(), + order_by.clone(), + ) + .await?; + exported += + writer.write_rows(&resp.data, &columns, Some(&schema), &table, &driver)?; + } + ExportScope::Filtered | ExportScope::FullTable => { + let filter_for_scope = if matches!(scope, ExportScope::Filtered) { + filter.clone() + } else { + None + }; + let order_for_scope = if matches!(scope, ExportScope::Filtered) { + order_by.clone() + } else { + None + }; + let sort_col_for_scope = if matches!(scope, ExportScope::Filtered) { + sort_column.clone() + } else { + None + }; + let sort_dir_for_scope = if matches!(scope, ExportScope::Filtered) { + sort_direction.clone() + } else { + None + }; + + let mut current_page = 1; + loop { + let resp = db_driver + .get_table_data_chunk( + schema.clone(), + table.clone(), + current_page, + chunk, + sort_col_for_scope.clone(), + sort_dir_for_scope.clone(), + filter_for_scope.clone(), + order_for_scope.clone(), + ) + .await?; + if resp.data.is_empty() { + break; + } + + exported += writer.write_rows( + &resp.data, + &columns, + Some(&schema), + &table, + &driver, + )?; + if exported >= resp.total { + break; + } + current_page += 1; + } + } + } + + writer.finish()?; + Ok(ExportResult { + file_path: 
output_path.to_string_lossy().to_string(), + row_count: exported, + }) + } + }) + .await +} + #[tauri::command] pub async fn export_query_result( state: State<'_, AppState>, @@ -200,6 +358,548 @@ pub async fn export_query_result( .await } +pub async fn export_query_result_direct( + state: &AppState, + id: i64, + database: Option, + sql: String, + driver: String, + format: ExportFormat, + file_path: Option, +) -> Result { + let output_path = + resolve_output_path(file_path, "query_result", extension_for_format(&format))?; + + super::execute_with_retry_from_app_state(state, id, database, |db_driver| { + let output_path = output_path.clone(); + let driver = driver.clone(); + let sql = sql.clone(); + let format = format.clone(); + async move { + let result = db_driver.execute_query(sql).await?; + let columns = result + .columns + .into_iter() + .map(|c| c.name) + .collect::>(); + let mut writer = ExportWriter::new(output_path.clone(), format, columns.clone())?; + let exported = writer.write_rows(&result.data, &columns, None, "query_result", &driver)?; + writer.finish()?; + Ok(ExportResult { + file_path: output_path.to_string_lossy().to_string(), + row_count: exported, + }) + } + }) + .await +} + +#[tauri::command] +pub async fn import_sql_file( + state: State<'_, AppState>, + id: i64, + database: Option, + file_path: String, + driver: String, +) -> Result { + let normalized_driver = normalize_driver_name(&driver); + let (begin_sql, commit_sql, rollback_sql) = + import_transaction_sql(&normalized_driver, &driver)?; + + let import_path = PathBuf::from(file_path.trim()); + validate_import_path(&import_path)?; + validate_import_file_size(&import_path)?; + + let source = fs::read_to_string(&import_path) + .map_err(|e| format!("[IMPORT_ERROR] failed to read sql file: {e}"))?; + let source = source + .strip_prefix('\u{feff}') + .unwrap_or(&source) + .to_string(); + + let import_plan = prepare_import_plan(&source, &normalized_driver)?; + if import_plan.units.is_empty() { + 
return Err("[IMPORT_ERROR] SQL file does not contain executable statements".to_string()); + } + if import_plan.units.len() > MAX_IMPORT_STATEMENTS { + return Err(format!( + "[IMPORT_ERROR] statement count exceeds limit ({} > {})", + import_plan.units.len(), + MAX_IMPORT_STATEMENTS + )); + } + + let started_at = std::time::Instant::now(); + let total_statements = import_plan.units.len() as i64; + let use_outer_transaction = !import_plan.script_managed_transaction; + + super::execute_with_retry(&state, id, database, |db_driver| { + let import_plan = import_plan.clone(); + let import_path = import_path.clone(); + async move { + if use_outer_transaction { + db_driver + .execute_query(begin_sql.to_string()) + .await + .map_err(|e| format!("[IMPORT_ERROR] failed to start transaction: {e}"))?; + } + + let mut success_statements = 0i64; + for (idx, unit) in import_plan.units.iter().enumerate() { + if let Err(e) = db_driver.execute_query(unit.sql.clone()).await { + if use_outer_transaction { + let _ = db_driver.execute_query(rollback_sql.to_string()).await; + } + return Ok(ImportSqlResult { + file_path: import_path.to_string_lossy().to_string(), + total_statements, + success_statements, + failed_at: Some((idx + 1) as i64), + failed_batch: Some(unit.batch_index as i64), + failed_statement_preview: Some(unit.preview.clone()), + error: Some(truncate_error_message(&e)), + time_taken_ms: started_at.elapsed().as_millis() as i64, + rolled_back: use_outer_transaction, + }); + } + success_statements += 1; + } + + if use_outer_transaction { + if let Err(e) = db_driver.execute_query(commit_sql.to_string()).await { + let _ = db_driver.execute_query(rollback_sql.to_string()).await; + return Ok(ImportSqlResult { + file_path: import_path.to_string_lossy().to_string(), + total_statements, + success_statements, + failed_at: None, + failed_batch: None, + failed_statement_preview: None, + error: Some(format!( + "[IMPORT_ERROR] failed to commit transaction: {}", + truncate_error_message(&e) + 
)), + time_taken_ms: started_at.elapsed().as_millis() as i64, + rolled_back: true, + }); + } + } + + Ok(ImportSqlResult { + file_path: import_path.to_string_lossy().to_string(), + total_statements, + success_statements: total_statements, + failed_at: None, + failed_batch: None, + failed_statement_preview: None, + error: None, + time_taken_ms: started_at.elapsed().as_millis() as i64, + rolled_back: false, + }) + } + }) + .await +} + +pub async fn import_sql_file_direct( + state: &AppState, + id: i64, + database: Option, + file_path: String, + driver: String, +) -> Result { + let normalized_driver = normalize_driver_name(&driver); + let (begin_sql, commit_sql, rollback_sql) = + import_transaction_sql(&normalized_driver, &driver)?; + + let import_path = PathBuf::from(file_path.trim()); + validate_import_path(&import_path)?; + validate_import_file_size(&import_path)?; + + let source = fs::read_to_string(&import_path) + .map_err(|e| format!("[IMPORT_ERROR] failed to read sql file: {e}"))?; + let source = source + .strip_prefix('\u{feff}') + .unwrap_or(&source) + .to_string(); + + let import_plan = prepare_import_plan(&source, &normalized_driver)?; + if import_plan.units.is_empty() { + return Err("[IMPORT_ERROR] SQL file does not contain executable statements".to_string()); + } + if import_plan.units.len() > MAX_IMPORT_STATEMENTS { + return Err(format!( + "[IMPORT_ERROR] statement count exceeds limit ({} > {})", + import_plan.units.len(), + MAX_IMPORT_STATEMENTS + )); + } + + let started_at = std::time::Instant::now(); + let total_statements = import_plan.units.len() as i64; + let use_outer_transaction = !import_plan.script_managed_transaction; + + super::execute_with_retry_from_app_state(state, id, database, |db_driver| { + let import_plan = import_plan.clone(); + let import_path = import_path.clone(); + async move { + if use_outer_transaction { + db_driver + .execute_query(begin_sql.to_string()) + .await + .map_err(|e| format!("[IMPORT_ERROR] failed to start 
transaction: {e}"))?; + } + + let mut success_statements = 0i64; + for (idx, unit) in import_plan.units.iter().enumerate() { + if let Err(e) = db_driver.execute_query(unit.sql.clone()).await { + if use_outer_transaction { + let _ = db_driver.execute_query(rollback_sql.to_string()).await; + } + return Ok(ImportSqlResult { + file_path: import_path.to_string_lossy().to_string(), + total_statements, + success_statements, + failed_at: Some((idx + 1) as i64), + failed_batch: Some(unit.batch_index as i64), + failed_statement_preview: Some(unit.preview.clone()), + error: Some(truncate_error_message(&e)), + time_taken_ms: started_at.elapsed().as_millis() as i64, + rolled_back: use_outer_transaction, + }); + } + success_statements += 1; + } + + if use_outer_transaction { + if let Err(e) = db_driver.execute_query(commit_sql.to_string()).await { + let _ = db_driver.execute_query(rollback_sql.to_string()).await; + return Ok(ImportSqlResult { + file_path: import_path.to_string_lossy().to_string(), + total_statements, + success_statements, + failed_at: None, + failed_batch: None, + failed_statement_preview: None, + error: Some(format!( + "[IMPORT_ERROR] failed to commit transaction: {}", + truncate_error_message(&e) + )), + time_taken_ms: started_at.elapsed().as_millis() as i64, + rolled_back: true, + }); + } + } + + Ok(ImportSqlResult { + file_path: import_path.to_string_lossy().to_string(), + total_statements, + success_statements: total_statements, + failed_at: None, + failed_batch: None, + failed_statement_preview: None, + error: None, + time_taken_ms: started_at.elapsed().as_millis() as i64, + rolled_back: false, + }) + } + }) + .await +} + +fn import_transaction_sql<'a>( + normalized_driver: &'a str, + original_driver: &str, +) -> Result<(&'a str, &'a str, &'a str), String> { + match normalized_driver { + "mysql" | "mariadb" | "tidb" => Ok(("START TRANSACTION", "COMMIT", "ROLLBACK")), + "postgres" | "sqlite" | "duckdb" => Ok(("BEGIN", "COMMIT", "ROLLBACK")), + "mssql" => 
Ok(( + "BEGIN TRANSACTION", + "COMMIT TRANSACTION", + "ROLLBACK TRANSACTION", + )), + "clickhouse" => { + Err("[UNSUPPORTED] Driver clickhouse is read-only in this import flow".to_string()) + } + _ => Err(format!( + "[UNSUPPORTED] Driver {} is not supported for SQL import", + original_driver + )), + } +} + +fn normalize_driver_name(driver: &str) -> String { + let normalized = driver.trim().to_ascii_lowercase(); + match normalized.as_str() { + "postgresql" | "pgsql" => "postgres".to_string(), + _ => normalized, + } +} + +fn prepare_import_plan(sql: &str, normalized_driver: &str) -> Result { + let units = if normalized_driver == "mssql" { + let batches = parse_mssql_batches(sql)?; + batches + .into_iter() + .enumerate() + .map(|(idx, batch)| ImportExecutionUnit { + preview: build_statement_preview(&batch), + sql: batch, + batch_index: idx + 1, + }) + .collect::>() + } else { + parse_sql_statements(sql, normalized_driver)? + .into_iter() + .enumerate() + .map(|(idx, statement)| ImportExecutionUnit { + preview: build_statement_preview(&statement), + sql: statement, + batch_index: idx + 1, + }) + .collect::>() + }; + + let script_managed_transaction = units + .iter() + .any(|unit| statement_controls_transaction(&unit.sql, normalized_driver)); + + Ok(PreparedImportPlan { + units, + script_managed_transaction, + }) +} + +fn build_statement_preview(statement: &str) -> String { + let compact = statement.split_whitespace().collect::>().join(" "); + let mut preview = String::new(); + for (idx, ch) in compact.chars().enumerate() { + if idx >= 160 { + preview.push_str("..."); + break; + } + preview.push(ch); + } + if preview.is_empty() { + "".to_string() + } else { + preview + } +} + +fn leading_sql_tokens(sql: &str, max_tokens: usize) -> Vec { + let chars: Vec = sql.chars().collect(); + let mut tokens = Vec::new(); + let mut i = 0usize; + + while i < chars.len() && tokens.len() < max_tokens { + let ch = chars[i]; + let next = chars.get(i + 1).copied(); + + if 
ch.is_whitespace() || ch == ';' { + i += 1; + continue; + } + + if ch == '-' && next == Some('-') { + i += 2; + while i < chars.len() && chars[i] != '\n' { + i += 1; + } + continue; + } + + if ch == '/' && next == Some('*') { + i += 2; + while i + 1 < chars.len() && !(chars[i] == '*' && chars[i + 1] == '/') { + i += 1; + } + if i + 1 < chars.len() { + i += 2; + } + continue; + } + + if ch.is_ascii_alphabetic() { + let start = i; + i += 1; + while i < chars.len() && (chars[i].is_ascii_alphabetic() || chars[i] == '_') { + i += 1; + } + tokens.push( + chars[start..i] + .iter() + .collect::() + .to_ascii_lowercase(), + ); + continue; + } + + i += 1; + } + + tokens +} + +fn statement_controls_transaction(statement: &str, normalized_driver: &str) -> bool { + let tokens = leading_sql_tokens(statement, 2); + if tokens.is_empty() { + return false; + } + + let first = tokens[0].as_str(); + let second = tokens.get(1).map(|s| s.as_str()).unwrap_or(""); + + match first { + "commit" | "rollback" => true, + "start" => second == "transaction", + "begin" => { + if normalized_driver == "mssql" { + second == "transaction" || second == "tran" + } else { + true + } + } + _ => false, + } +} + +fn parse_mssql_go_line_count(line: &str) -> Option { + let trimmed = line.trim(); + let prefix = trimmed.get(..2)?; + if !prefix.eq_ignore_ascii_case("go") { + return None; + } + let rest = trimmed[2..].trim(); + if rest.is_empty() { + return Some(1); + } + if rest.chars().all(|ch| ch.is_ascii_digit()) { + let count = rest.parse::().ok()?; + if count > 0 { + return Some(count); + } + } + None +} + +fn update_mssql_line_state(state: &mut SqlScanState, line: &str) { + let chars: Vec = line.chars().collect(); + let mut i = 0usize; + + while i < chars.len() { + match state { + SqlScanState::Normal => { + let ch = chars[i]; + let next = chars.get(i + 1).copied(); + if ch == '-' && next == Some('-') { + *state = SqlScanState::LineComment; + break; + } + if ch == '/' && next == Some('*') { + *state = 
SqlScanState::BlockComment; + i += 2; + continue; + } + if ch == '\'' { + *state = SqlScanState::SingleQuoted; + i += 1; + continue; + } + if ch == '"' { + *state = SqlScanState::DoubleQuoted; + i += 1; + continue; + } + i += 1; + } + SqlScanState::SingleQuoted => { + if chars[i] == '\'' { + if chars.get(i + 1) == Some(&'\'') { + i += 2; + continue; + } + *state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::DoubleQuoted => { + if chars[i] == '"' { + if chars.get(i + 1) == Some(&'"') { + i += 2; + continue; + } + *state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::BlockComment => { + if chars[i] == '*' && chars.get(i + 1) == Some(&'/') { + *state = SqlScanState::Normal; + i += 2; + continue; + } + i += 1; + } + SqlScanState::LineComment => { + break; + } + SqlScanState::BacktickQuoted | SqlScanState::DollarQuoted(_) => { + *state = SqlScanState::Normal; + } + } + } + + if matches!(state, SqlScanState::LineComment) { + *state = SqlScanState::Normal; + } +} + +fn parse_mssql_batches(sql: &str) -> Result, String> { + let mut out = Vec::new(); + let mut current = String::new(); + let mut state = SqlScanState::Normal; + + for line in sql.split_inclusive('\n') { + if matches!(state, SqlScanState::Normal) { + let plain_line = line.trim_end_matches(|ch| ch == '\r' || ch == '\n'); + if let Some(go_count) = parse_mssql_go_line_count(plain_line) { + let statement = current.trim(); + if !statement.is_empty() { + for _ in 0..go_count { + out.push(statement.to_string()); + } + } + current.clear(); + continue; + } + } + + update_mssql_line_state(&mut state, line); + current.push_str(line); + } + + match state { + SqlScanState::Normal | SqlScanState::LineComment => {} + SqlScanState::BlockComment => { + return Err("[IMPORT_ERROR] Unterminated block comment in SQL file".to_string()); + } + SqlScanState::SingleQuoted + | SqlScanState::DoubleQuoted + | SqlScanState::BacktickQuoted + | SqlScanState::DollarQuoted(_) => { + return Err("[IMPORT_ERROR] Unterminated 
string literal in SQL file".to_string()); + } + } + + let tail = current.trim(); + if !tail.is_empty() { + out.push(tail.to_string()); + } + Ok(out) +} + fn extension_for_format(format: &ExportFormat) -> &'static str { match format { ExportFormat::Csv => "csv", @@ -232,6 +932,250 @@ fn resolve_output_path( Ok(path) } +fn validate_import_path(path: &Path) -> Result<(), String> { + if path.as_os_str().is_empty() { + return Err("[IMPORT_ERROR] Invalid import path".to_string()); + } + if path.is_dir() { + return Err("[IMPORT_ERROR] Import path points to a directory".to_string()); + } + if !path.exists() { + return Err("[IMPORT_ERROR] Import file does not exist".to_string()); + } + let Some(ext) = path.extension().and_then(|v| v.to_str()) else { + return Err("[IMPORT_ERROR] Import file must use .sql extension".to_string()); + }; + if !ext.eq_ignore_ascii_case("sql") { + return Err("[IMPORT_ERROR] Import file must use .sql extension".to_string()); + } + Ok(()) +} + +fn validate_import_file_size(path: &Path) -> Result<(), String> { + let metadata = fs::metadata(path) + .map_err(|e| format!("[IMPORT_ERROR] failed to read file metadata: {e}"))?; + if metadata.len() > MAX_IMPORT_FILE_SIZE_BYTES { + return Err(format!( + "[IMPORT_ERROR] file is too large (max {} bytes)", + MAX_IMPORT_FILE_SIZE_BYTES + )); + } + Ok(()) +} + +#[derive(Debug, Clone)] +enum SqlScanState { + Normal, + SingleQuoted, + DoubleQuoted, + BacktickQuoted, + DollarQuoted(String), + LineComment, + BlockComment, +} + +fn parse_sql_statements(sql: &str, driver: &str) -> Result, String> { + let mysql_style_hash_comment = matches!(driver, "mysql" | "mariadb" | "tidb"); + let chars: Vec = sql.chars().collect(); + let mut out = Vec::new(); + let mut current = String::new(); + let mut state = SqlScanState::Normal; + let mut i = 0usize; + + while i < chars.len() { + match &state { + SqlScanState::Normal => { + let ch = chars[i]; + let next = chars.get(i + 1).copied(); + + if ch == '-' && next == Some('-') { + 
state = SqlScanState::LineComment; + i += 2; + continue; + } + if mysql_style_hash_comment && ch == '#' { + state = SqlScanState::LineComment; + i += 1; + continue; + } + if ch == '/' && next == Some('*') { + state = SqlScanState::BlockComment; + i += 2; + continue; + } + if ch == '\'' { + current.push(ch); + state = SqlScanState::SingleQuoted; + i += 1; + continue; + } + if ch == '"' { + current.push(ch); + state = SqlScanState::DoubleQuoted; + i += 1; + continue; + } + if ch == '`' { + current.push(ch); + state = SqlScanState::BacktickQuoted; + i += 1; + continue; + } + if ch == '$' { + if let Some((tag, end_idx)) = parse_dollar_quote_tag(&chars, i) { + current.push_str(&tag); + state = SqlScanState::DollarQuoted(tag); + i = end_idx + 1; + continue; + } + } + if ch == ';' { + let statement = current.trim(); + if !statement.is_empty() { + out.push(statement.to_string()); + } + current.clear(); + i += 1; + continue; + } + current.push(ch); + i += 1; + } + SqlScanState::SingleQuoted => { + let ch = chars[i]; + current.push(ch); + if ch == '\\' { + if let Some(next) = chars.get(i + 1) { + current.push(*next); + i += 2; + continue; + } + } + if ch == '\'' { + if chars.get(i + 1) == Some(&'\'') { + current.push('\''); + i += 2; + continue; + } + state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::DoubleQuoted => { + let ch = chars[i]; + current.push(ch); + if ch == '"' { + if chars.get(i + 1) == Some(&'"') { + current.push('"'); + i += 2; + continue; + } + state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::BacktickQuoted => { + let ch = chars[i]; + current.push(ch); + if ch == '`' { + if chars.get(i + 1) == Some(&'`') { + current.push('`'); + i += 2; + continue; + } + state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::DollarQuoted(tag) => { + if starts_with_tag(&chars, i, tag) { + current.push_str(tag); + i += tag.chars().count(); + state = SqlScanState::Normal; + continue; + } + current.push(chars[i]); + i += 1; + } + 
SqlScanState::LineComment => { + if chars[i] == '\n' { + current.push('\n'); + state = SqlScanState::Normal; + } + i += 1; + } + SqlScanState::BlockComment => { + if chars[i] == '*' && chars.get(i + 1) == Some(&'/') { + state = SqlScanState::Normal; + i += 2; + } else { + i += 1; + } + } + } + } + + match state { + SqlScanState::Normal | SqlScanState::LineComment => {} + SqlScanState::BlockComment => { + return Err("[IMPORT_ERROR] Unterminated block comment in SQL file".to_string()); + } + SqlScanState::SingleQuoted + | SqlScanState::DoubleQuoted + | SqlScanState::BacktickQuoted + | SqlScanState::DollarQuoted(_) => { + return Err("[IMPORT_ERROR] Unterminated string literal in SQL file".to_string()); + } + } + + let tail = current.trim(); + if !tail.is_empty() { + out.push(tail.to_string()); + } + Ok(out) +} + +fn parse_dollar_quote_tag(chars: &[char], start: usize) -> Option<(String, usize)> { + if chars.get(start) != Some(&'$') { + return None; + } + let mut idx = start + 1; + while idx < chars.len() && (chars[idx].is_ascii_alphanumeric() || chars[idx] == '_') { + idx += 1; + } + if idx < chars.len() && chars[idx] == '$' { + let tag: String = chars[start..=idx].iter().collect(); + return Some((tag, idx)); + } + None +} + +fn starts_with_tag(chars: &[char], idx: usize, tag: &str) -> bool { + let tag_chars: Vec = tag.chars().collect(); + if idx + tag_chars.len() > chars.len() { + return false; + } + for (offset, ch) in tag_chars.iter().enumerate() { + if chars[idx + offset] != *ch { + return false; + } + } + true +} + +fn truncate_error_message(message: &str) -> String { + const MAX_CHARS: usize = 500; + let mut out = String::new(); + for (idx, ch) in message.chars().enumerate() { + if idx >= MAX_CHARS { + out.push_str("..."); + break; + } + out.push(ch); + } + out +} + fn validate_output_path(path: &PathBuf) -> Result<(), String> { if path.as_os_str().is_empty() { return Err("[EXPORT_ERROR] Invalid output path".to_string()); @@ -508,7 +1452,10 @@ mod tests { 
sql_value(&Value::String("O'Reilly".to_string())), "'O''Reilly'" ); - assert_eq!(sql_value(&Value::Number(serde_json::Number::from(42))), "42"); + assert_eq!( + sql_value(&Value::Number(serde_json::Number::from(42))), + "42" + ); assert_eq!(sql_value(&Value::Bool(false)), "FALSE"); } @@ -613,4 +1560,132 @@ mod tests { assert_eq!(err, "[EXPORT_ERROR] row is not a JSON object"); let _ = fs::remove_file(path); } + + #[test] + fn parse_sql_statements_handles_quotes_and_comments() { + let sql = r#" + -- comment 1 + INSERT INTO users (name, note) VALUES ('alice', 'hello;world'); + /* block comment ; ; */ + INSERT INTO users (name) VALUES ("bob"); + # mysql style comment + INSERT INTO users(name) VALUES ($tag$semi;inside$tag$); + "#; + + let statements = parse_sql_statements(sql, "mysql").unwrap(); + assert_eq!(statements.len(), 3); + assert!(statements[0].starts_with("INSERT INTO users")); + assert!(statements[1].contains("\"bob\"")); + assert!(statements[2].contains("$tag$semi;inside$tag$")); + } + + #[test] + fn parse_sql_statements_rejects_unterminated_block_comment() { + let err = parse_sql_statements("INSERT INTO t VALUES (1); /*", "mysql").unwrap_err(); + assert!(err.contains("Unterminated block comment")); + } + + #[test] + fn parse_sql_statements_preserves_hash_for_postgres() { + let sql = "SELECT 1 # 2;\nSELECT '#not_comment';"; + let statements = parse_sql_statements(sql, "postgres").unwrap(); + assert_eq!(statements.len(), 2); + assert_eq!(statements[0], "SELECT 1 # 2"); + assert_eq!(statements[1], "SELECT '#not_comment'"); + } + + #[test] + fn parse_mssql_batches_splits_on_go_lines_only() { + let sql = r#" + SELECT 1; + GO + SELECT 'GO should stay in string'; + -- GO in comment should not split + SELECT 2; + GO + /* GO in block comment + GO + */ + SELECT 3; + "#; + + let batches = parse_mssql_batches(sql).unwrap(); + assert_eq!(batches.len(), 3); + assert!(batches[0].contains("SELECT 1")); + assert!(batches[1].contains("SELECT 'GO should stay in string'")); + 
assert!(batches[2].contains("SELECT 3")); + } + + #[test] + fn parse_mssql_batches_supports_go_repeat_count() { + let sql = "SELECT 1\nGO 3\nSELECT 2\nGO"; + let batches = parse_mssql_batches(sql).unwrap(); + assert_eq!(batches.len(), 4); + assert_eq!(batches[0], "SELECT 1"); + assert_eq!(batches[1], "SELECT 1"); + assert_eq!(batches[2], "SELECT 1"); + assert_eq!(batches[3], "SELECT 2"); + } + + #[test] + fn statement_controls_transaction_detects_driver_specific_tokens() { + assert!(statement_controls_transaction("BEGIN TRANSACTION", "mssql")); + assert!(!statement_controls_transaction("BEGIN TRY", "mssql")); + assert!(statement_controls_transaction("BEGIN", "sqlite")); + assert!(statement_controls_transaction("START TRANSACTION", "mysql")); + assert!(statement_controls_transaction("ROLLBACK", "postgres")); + } + + #[test] + fn prepare_import_plan_disables_outer_tx_when_script_controls_it() { + let sqlite_plan = + prepare_import_plan("BEGIN;\nCREATE TABLE t(id INTEGER);\nCOMMIT;", "sqlite").unwrap(); + assert_eq!(sqlite_plan.units.len(), 3); + assert!(sqlite_plan.script_managed_transaction); + + let mssql_plan = prepare_import_plan("SELECT 1\nGO\nSELECT 2", "mssql").unwrap(); + assert_eq!(mssql_plan.units.len(), 2); + assert!(!mssql_plan.script_managed_transaction); + } + + #[test] + fn import_transaction_sql_maps_per_driver() { + assert_eq!( + import_transaction_sql("mysql", "mysql").unwrap(), + ("START TRANSACTION", "COMMIT", "ROLLBACK") + ); + assert_eq!( + import_transaction_sql("postgres", "postgres").unwrap(), + ("BEGIN", "COMMIT", "ROLLBACK") + ); + assert_eq!( + import_transaction_sql("postgres", "postgresql").unwrap(), + ("BEGIN", "COMMIT", "ROLLBACK") + ); + assert_eq!( + import_transaction_sql("mssql", "mssql").unwrap(), + ( + "BEGIN TRANSACTION", + "COMMIT TRANSACTION", + "ROLLBACK TRANSACTION" + ) + ); + assert!(import_transaction_sql("clickhouse", "clickhouse").is_err()); + } + + #[test] + fn normalize_driver_name_maps_aliases() { + 
assert_eq!(normalize_driver_name("postgres"), "postgres"); + assert_eq!(normalize_driver_name("postgresql"), "postgres"); + assert_eq!(normalize_driver_name("pgsql"), "postgres"); + assert_eq!(normalize_driver_name("mysql"), "mysql"); + } + + #[test] + fn truncate_error_message_caps_length() { + let source = "x".repeat(600); + let truncated = truncate_error_message(&source); + assert!(truncated.len() <= 503); + assert!(truncated.ends_with("...")); + } } diff --git a/src-tauri/src/connection_input/mod.rs b/src-tauri/src/connection_input/mod.rs index 8e5af95..8a347fb 100644 --- a/src-tauri/src/connection_input/mod.rs +++ b/src-tauri/src/connection_input/mod.rs @@ -6,6 +6,10 @@ fn trim_to_option(value: Option) -> Option { .and_then(|v| if v.is_empty() { None } else { Some(v) }) } +fn trim_preserve_empty(value: Option) -> Option { + value.map(|v| v.trim().to_string()) +} + fn parse_host_embedded_port(host: &str, fallback_port: Option) -> (String, Option) { if host.starts_with('[') || host.contains(' ') || host.matches(':').count() != 1 { return (host.to_string(), fallback_port); @@ -42,12 +46,12 @@ pub fn normalize_connection_form(mut form: ConnectionForm) -> Result Result { - let value = row - .and_then(|v| v.get(key)) - .ok_or_else(|| format!("[PARSE_ERROR] Missing '{}' in response for SQL: {}", key, context_sql))?; + let value = row.and_then(|v| v.get(key)).ok_or_else(|| { + format!( + "[PARSE_ERROR] Missing '{}' in response for SQL: {}", + key, context_sql + ) + })?; value_to_i64(value).ok_or_else(|| { format!( "[PARSE_ERROR] Invalid '{}' value {:?} for SQL: {}", diff --git a/src-tauri/src/db/drivers/duckdb.rs b/src-tauri/src/db/drivers/duckdb.rs index 0c1ac48..0238483 100644 --- a/src-tauri/src/db/drivers/duckdb.rs +++ b/src-tauri/src/db/drivers/duckdb.rs @@ -859,7 +859,10 @@ mod tests { .unwrap(); assert_eq!(returning_result.row_count, 1); assert_eq!(returning_result.columns.len(), 2); - assert_eq!(returning_result.data[0]["id"], 
serde_json::Value::String("3".to_string())); + assert_eq!( + returning_result.data[0]["id"], + serde_json::Value::String("3".to_string()) + ); assert_eq!( returning_result.data[0]["name"], serde_json::Value::String("c".to_string()) @@ -871,7 +874,10 @@ mod tests { #[test] fn test_number_from_f64_nan_and_inf_are_stringified() { - assert_eq!(number_from_f64(f64::NAN), serde_json::Value::String("NaN".to_string())); + assert_eq!( + number_from_f64(f64::NAN), + serde_json::Value::String("NaN".to_string()) + ); assert_eq!( number_from_f64(f64::INFINITY), serde_json::Value::String("inf".to_string()) diff --git a/src-tauri/src/db/drivers/mod.rs b/src-tauri/src/db/drivers/mod.rs index 442041f..f83d407 100644 --- a/src-tauri/src/db/drivers/mod.rs +++ b/src-tauri/src/db/drivers/mod.rs @@ -17,6 +17,50 @@ pub mod mysql; pub mod postgres; pub mod sqlite; +/// Build a `[CONN_FAILED]` error message with a context-aware hint derived from the +/// underlying error text, so users are not misled by a generic credential warning +/// when the actual problem is TLS incompatibility, a network issue, etc. 
+pub(crate) fn conn_failed_error(e: &dyn std::fmt::Display) -> String { + let raw = e.to_string(); + let lower = raw.to_ascii_lowercase(); + + let hint = if lower.contains("handshake") + || lower.contains("fatal alert") + || lower.contains("tls") + || lower.contains("ssl") + || lower.contains("certificate") + { + "hint: TLS/SSL handshake failed — the server may use a TLS version or cipher suite \ + incompatible with the client (TLS 1.2+ required); try disabling SSL in the connection settings" + } else if lower.contains("access denied") + || lower.contains("authentication") + || lower.contains("password") + || lower.contains("login failed") + || lower.contains("invalid password") + || lower.contains("1045") + { + "hint: authentication failed — verify the username/password are correct; \ + if they contain special characters they must be URL-encoded" + } else if lower.contains("connection refused") + || lower.contains("timed out") + || lower.contains("timeout") + || lower.contains("broken pipe") + || lower.contains("network unreachable") + { + "hint: could not reach the server — check host, port, firewall rules, and SSH tunnel settings" + } else if lower.contains("name resolution") + || lower.contains("no such host") + || lower.contains("failed to lookup") + || lower.contains("dns") + { + "hint: hostname could not be resolved — check that the host address is correct" + } else { + "hint: check host, port, credentials, and SSL settings" + }; + + format!("[CONN_FAILED] {raw} ({hint})") +} + pub(crate) fn strip_trailing_statement_terminator(sql: &str) -> &str { let mut out = sql.trim_end(); while let Some(stripped) = out.strip_suffix(';') { @@ -121,7 +165,50 @@ pub async fn connect(form: &ConnectionForm) -> Result, S #[cfg(test)] mod tests { - use super::strip_trailing_statement_terminator; + use super::{conn_failed_error, strip_trailing_statement_terminator}; + + #[test] + fn conn_failed_error_tls_hint() { + let msg = conn_failed_error( + &"error communicating with 
database: received fatal alert: HandshakeFailure", + ); + assert!(msg.starts_with("[CONN_FAILED]")); + assert!(msg.contains("TLS/SSL handshake failed")); + assert!(!msg.contains("username/password")); + } + + #[test] + fn conn_failed_error_auth_hint() { + let msg = conn_failed_error(&"Access denied for user 'root'@'localhost'"); + assert!(msg.contains("authentication failed")); + assert!(msg.contains("URL-encoded")); + } + + #[test] + fn conn_failed_error_connection_refused_hint() { + let msg = conn_failed_error(&"Connection refused (os error 111)"); + assert!(msg.contains("could not reach the server")); + } + + #[test] + fn conn_failed_error_timeout_hint() { + let msg = conn_failed_error(&"connection timed out"); + assert!(msg.contains("could not reach the server")); + } + + #[test] + fn conn_failed_error_dns_hint() { + let msg = conn_failed_error(&"failed to lookup address information: no such host"); + assert!(msg.contains("hostname could not be resolved")); + } + + #[test] + fn conn_failed_error_generic_hint() { + let msg = conn_failed_error(&"some unknown database error"); + assert!(msg.starts_with("[CONN_FAILED]")); + assert!(msg.contains("hint:")); + assert!(!msg.contains("username/password")); + } #[test] fn strip_trailing_statement_terminator_removes_single_semicolon() { diff --git a/src-tauri/src/db/drivers/mssql.rs b/src-tauri/src/db/drivers/mssql.rs index 23811f7..70f6bef 100644 --- a/src-tauri/src/db/drivers/mssql.rs +++ b/src-tauri/src/db/drivers/mssql.rs @@ -4,6 +4,7 @@ use crate::models::{ SchemaOverview, TableDataResponse, TableInfo, TableMetadata, TableSchema, TableStructure, }; use async_trait::async_trait; +use bb8::{Pool, RunError}; use futures_util::TryStreamExt; use std::collections::{HashMap, HashSet}; use tiberius::{AuthMethod, Client, Config, EncryptionLevel, QueryItem, Row}; @@ -13,10 +14,15 @@ use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; use crate::ssh::SshTunnel; pub struct MssqlDriver { - config: MssqlConfig, + pub pool: 
Pool, pub ssh_tunnel: Option, } +pub struct MssqlConnectionManager { + config: MssqlConfig, +} + +#[derive(Clone)] struct MssqlConfig { host: String, port: u16, @@ -65,6 +71,13 @@ fn escape_literal(value: &str) -> String { value.replace('\'', "''") } +fn map_pool_error(err: RunError) -> String { + match err { + RunError::User(inner) => inner, + RunError::TimedOut => "[CONN_FAILED] Timed out acquiring MSSQL connection".to_string(), + } +} + fn quote_ident(ident: &str) -> Result { let trimmed = ident.trim(); if trimmed.is_empty() { @@ -128,7 +141,11 @@ fn first_sql_keyword(sql: &str) -> Option { Some(sql[start..i].to_ascii_lowercase()) } -impl MssqlDriver { +impl MssqlConnectionManager { + fn new(config: MssqlConfig) -> Self { + Self { config } + } + fn build_tiberius_config(&self, encryption: EncryptionLevel, trust_cert: bool) -> Config { let mut config = Config::new(); config.host(&self.config.host); @@ -150,36 +167,7 @@ impl MssqlDriver { config } - async fn connect_with_config(config: Config) -> Result>, String> { - let tcp = TcpStream::connect(config.get_addr()) - .await - .map_err(|e| format!("[CONN_FAILED] {}", e))?; - tcp.set_nodelay(true) - .map_err(|e| format!("[CONN_FAILED] {}", e))?; - - Client::connect(config, tcp.compat_write()) - .await - .map_err(|e| format!("[CONN_FAILED] {}", e)) - } - - pub async fn connect(form: &ConnectionForm) -> Result { - let mut cfg_form = form.clone(); - let mut ssh_tunnel = None; - - if let Some(true) = form.ssh_enabled { - let tunnel = crate::ssh::start_ssh_tunnel(form)?; - cfg_form.host = Some("127.0.0.1".to_string()); - cfg_form.port = Some(tunnel.local_port as i64); - ssh_tunnel = Some(tunnel); - } - - let config = build_config(&cfg_form)?; - let driver = Self { config, ssh_tunnel }; - driver.test_connection().await?; - Ok(driver) - } - - async fn connect_client(&self) -> Result>, String> { + async fn connect_single(&self) -> Result>, String> { let attempts = if self.config.ssl { vec![ ( @@ -221,6 +209,72 @@ impl 
MssqlDriver { )) } + async fn connect_with_config(config: Config) -> Result>, String> { + let connect_future = async { + let tcp = TcpStream::connect(config.get_addr()) + .await + .map_err(|e| format!("{}", e))?; + tcp.set_nodelay(true).map_err(|e| format!("{}", e))?; + Ok::(tcp) + }; + + let tcp = tokio::time::timeout(std::time::Duration::from_secs(10), connect_future) + .await + .map_err(|_| "Connection timed out".to_string())? + .map_err(|e| format!("{}", e))?; + + Client::connect(config, tcp.compat_write()) + .await + .map_err(|e| format!("{}", e)) + } +} + +#[async_trait] +impl bb8::ManageConnection for MssqlConnectionManager { + type Connection = Client>; + type Error = String; + + async fn connect(&self) -> Result { + self.connect_single().await + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + conn.simple_query("SELECT 1") + .await + .map_err(|e| format!("{}", e))?; + Ok(()) + } + + fn has_broken(&self, _conn: &mut Self::Connection) -> bool { + false + } +} + +impl MssqlDriver { + pub async fn connect(form: &ConnectionForm) -> Result { + let mut cfg_form = form.clone(); + let mut ssh_tunnel = None; + + if let Some(true) = form.ssh_enabled { + let tunnel = crate::ssh::start_ssh_tunnel(form)?; + cfg_form.host = Some("127.0.0.1".to_string()); + cfg_form.port = Some(tunnel.local_port as i64); + ssh_tunnel = Some(tunnel); + } + + let config = build_config(&cfg_form)?; + let manager = MssqlConnectionManager::new(config); + let pool = Pool::builder() + .max_size(10) + .build(manager) + .await + .map_err(|e| format!("[CONN_FAILED] Failed to create connection pool: {}", e))?; + + let driver = Self { pool, ssh_tunnel }; + driver.test_connection().await?; + Ok(driver) + } + async fn fetch_rows(&self, sql: &str) -> Result, String> { Ok(self.fetch_rows_with_columns(sql).await?.0) } @@ -229,7 +283,7 @@ impl MssqlDriver { &self, sql: &str, ) -> Result<(Vec, Vec), String> { - let mut client = self.connect_client().await?; + let 
mut client = self.pool.get().await.map_err(map_pool_error)?; let mut stream = client .simple_query(sql) .await @@ -438,7 +492,10 @@ mod tests { }); let hp = HashSet::from(["ID".to_string(), "amount".to_string()]); super::normalize_mssql_row_json(&mut row, &hp).unwrap(); - assert_eq!(row.get("id").and_then(|v| v.as_str()), Some("9223372036854775807")); + assert_eq!( + row.get("id").and_then(|v| v.as_str()), + Some("9223372036854775807") + ); assert_eq!(row.get("amount").and_then(|v| v.as_str()), Some("1234.56")); assert_eq!(row.get("name").and_then(|v| v.as_str()), Some("x")); } @@ -750,7 +807,9 @@ impl DatabaseDriver for MssqlDriver { qualified, where_clause, order_clause, offset, safe_limit ); let json_sql = Self::build_for_json_query(&sql); - let data = self.fetch_json_rows(&json_sql, &high_precision_cols).await?; + let data = self + .fetch_json_rows(&json_sql, &high_precision_cols) + .await?; Ok(TableDataResponse { data, @@ -801,7 +860,9 @@ impl DatabaseDriver for MssqlDriver { .map(|col| col.name.clone()) .collect(); let json_sql = Self::build_for_json_query(&sql); - let data = self.fetch_json_rows(&json_sql, &high_precision_cols).await?; + let data = self + .fetch_json_rows(&json_sql, &high_precision_cols) + .await?; return Ok(QueryResult { row_count: data.len() as i64, @@ -813,7 +874,7 @@ impl DatabaseDriver for MssqlDriver { }); } - let mut client = self.connect_client().await?; + let mut client = self.pool.get().await.map_err(map_pool_error)?; let result = client .execute(&sql, &[]) .await diff --git a/src-tauri/src/db/drivers/mysql.rs b/src-tauri/src/db/drivers/mysql.rs index 4712456..fd2ed8c 100644 --- a/src-tauri/src/db/drivers/mysql.rs +++ b/src-tauri/src/db/drivers/mysql.rs @@ -132,6 +132,10 @@ fn build_dsn_and_ca_path(form: &ConnectionForm) -> Result<(String, Option Result { Ok(build_dsn_and_ca_path(form)?.0) } +#[cfg(test)] +pub(crate) fn build_test_dsn(form: &ConnectionForm) -> Result { + build_dsn(form) +} + fn build_dsn_with_ca_path(form: 
&ConnectionForm) -> Result<(String, Option), String> { build_dsn_and_ca_path(form) } @@ -156,6 +165,11 @@ fn cleanup_ca_file_opt(path: Option<&PathBuf>) { } } +fn is_prepared_protocol_unsupported_error(err: &str) -> bool { + let lower = err.to_ascii_lowercase(); + lower.contains("1295") || lower.contains("prepared statement protocol") +} + impl Drop for MysqlDriver { fn drop(&mut self) { cleanup_ca_file_opt(self.ca_cert_path.as_ref()); @@ -184,11 +198,7 @@ impl MysqlDriver { .acquire_timeout(std::time::Duration::from_secs(3)) .connect(&dsn) .await - .map_err(|e| { - format!( - "[CONN_FAILED] {e} (hint: check if username/password contain special characters; they must be URL-encoded)" - ) - })?; + .map_err(|e| super::conn_failed_error(&e))?; Ok(Self { pool, @@ -383,7 +393,10 @@ fn normalize_mysql_row_json( Ok(()) } -fn decode_mysql_json_cell(row: &sqlx::mysql::MySqlRow, column_name: &str) -> Result { +fn decode_mysql_json_cell( + row: &sqlx::mysql::MySqlRow, + column_name: &str, +) -> Result { if let Ok(v) = row.try_get::, _>(column_name) { return Ok(v.0); } @@ -398,10 +411,7 @@ fn decode_mysql_json_cell(row: &sqlx::mysql::MySqlRow, column_name: &str) -> Res Err("[QUERY_ERROR] Failed to decode MySQL JSON cell".to_string()) } -fn build_mysql_json_object_expr( - columns: &[(String, String)], - table_alias: Option<&str>, -) -> String { +fn build_mysql_json_object_expr(columns: &[(String, String)], table_alias: Option<&str>) -> String { if columns.is_empty() { return "JSON_OBJECT()".to_string(); } @@ -447,6 +457,13 @@ fn is_json_projectable_statement(sql: &str) -> bool { matches!(first_sql_keyword(sql).as_deref(), Some("SELECT" | "WITH")) } +fn is_affected_rows_statement(sql: &str) -> bool { + matches!( + first_sql_keyword(sql).as_deref(), + Some("INSERT" | "UPDATE" | "DELETE" | "REPLACE") + ) +} + #[async_trait] impl DatabaseDriver for MysqlDriver { async fn close(&self) { @@ -851,11 +868,30 @@ impl DatabaseDriver for MysqlDriver { .await?; let row_count = data.len() as 
i64; (columns, data, row_count) - } else { - let rows = sqlx::query(&sql) - .fetch_all(&self.pool) + } else if is_affected_rows_statement(&sql) { + let result = sqlx::query(&sql) + .execute(&self.pool) .await .map_err(|e| format!("[QUERY_ERROR] {e}"))?; + (Vec::new(), Vec::new(), result.rows_affected() as i64) + } else { + let mut executed_with_raw_sql = false; + let rows = match sqlx::query(&sql).fetch_all(&self.pool).await { + Ok(rows) => rows, + Err(e) => { + let error_text = e.to_string(); + if is_prepared_protocol_unsupported_error(&error_text) { + sqlx::raw_sql(&sql) + .execute(&self.pool) + .await + .map_err(|raw_err| format!("[QUERY_ERROR] {raw_err}"))?; + executed_with_raw_sql = true; + Vec::new() + } else { + return Err(format!("[QUERY_ERROR] {e}")); + } + } + }; let columns = if let Some(first_row) = rows.first() { first_row .columns() @@ -865,6 +901,8 @@ impl DatabaseDriver for MysqlDriver { r#type: col.type_info().to_string(), }) .collect() + } else if executed_with_raw_sql { + Vec::new() } else { self.describe_query_columns(&sql).await? 
}; @@ -1020,7 +1058,10 @@ mod tests { }; let conn_str = build_dsn(&form).unwrap(); - assert_eq!(conn_str, "mysql://root:password@localhost:3306/test_db"); + assert_eq!( + conn_str, + "mysql://root:password@localhost:3306/test_db?ssl-mode=DISABLED" + ); } #[test] @@ -1036,7 +1077,26 @@ mod tests { }; let conn_str = build_dsn(&form).unwrap(); - assert_eq!(conn_str, "mysql://user:pass@127.0.0.1:3307"); + assert_eq!( + conn_str, + "mysql://user:pass@127.0.0.1:3307?ssl-mode=DISABLED" + ); + } + + #[test] + fn test_conn_string_allows_empty_password_when_present() { + let form = ConnectionForm { + driver: "mysql".to_string(), + host: Some("127.0.0.1".to_string()), + port: Some(3307), + username: Some("user".to_string()), + password: Some(String::new()), + database: None, + ..Default::default() + }; + + let conn_str = build_dsn(&form).unwrap(); + assert_eq!(conn_str, "mysql://user:@127.0.0.1:3307?ssl-mode=DISABLED"); } #[test] @@ -1052,7 +1112,10 @@ mod tests { }; let conn_str = build_dsn(&form).unwrap(); - assert_eq!(conn_str, "mysql://user:pass@127.0.0.1:3307"); + assert_eq!( + conn_str, + "mysql://user:pass@127.0.0.1:3307?ssl-mode=DISABLED" + ); } #[test] @@ -1068,7 +1131,10 @@ mod tests { }; let conn_str = build_dsn(&form).unwrap(); - assert_eq!(conn_str, "mysql://root:password@localhost:3308/test_db"); + assert_eq!( + conn_str, + "mysql://root:password@localhost:3308/test_db?ssl-mode=DISABLED" + ); } #[test] @@ -1086,7 +1152,35 @@ mod tests { let conn_str = build_dsn(&form).unwrap(); assert_eq!( conn_str, - "mysql://user%40name:p%40ss%3Aword%23%3F@localhost:3306/test_db" + "mysql://user%40name:p%40ss%3Aword%23%3F@localhost:3306/test_db?ssl-mode=DISABLED" + ); + } + + #[test] + fn test_conn_string_encodes_credentials_when_ssh_rewrites_target_host() { + let mut form = ConnectionForm { + driver: "mysql".to_string(), + host: Some("db.internal".to_string()), + port: Some(3306), + username: Some("user@name".to_string()), + password: Some("p#ss*@)".to_string()), + database: 
Some("test_db".to_string()), + ssh_enabled: Some(true), + ssh_host: Some("bastion.internal".to_string()), + ssh_port: Some(22), + ssh_username: Some("jump".to_string()), + ssh_password: Some("ssh#pass".to_string()), + ..Default::default() + }; + + // Match the production flow after the SSH tunnel assigns a local endpoint. + form.host = Some("127.0.0.1".to_string()); + form.port = Some(4406); + + let conn_str = build_dsn(&form).unwrap(); + assert_eq!( + conn_str, + "mysql://user%40name:p%23ss%2A%40%29@127.0.0.1:4406/test_db?ssl-mode=DISABLED" ); } @@ -1126,7 +1220,7 @@ mod tests { } #[test] - fn test_conn_string_with_ssl_false_does_not_explicitly_disable_tls() { + fn test_conn_string_with_ssl_false_explicitly_disables_tls() { let form = ConnectionForm { driver: "mysql".to_string(), host: Some("localhost".to_string()), @@ -1139,9 +1233,11 @@ mod tests { }; let conn_str = build_dsn(&form).unwrap(); - assert_eq!(conn_str, "mysql://root:password@localhost:3306/test_db"); - assert!(!conn_str.contains("ssl-mode=")); - assert!(!conn_str.contains("DISABLED")); + assert_eq!( + conn_str, + "mysql://root:password@localhost:3306/test_db?ssl-mode=DISABLED" + ); + assert!(conn_str.contains("ssl-mode=DISABLED")); } #[test] @@ -1194,7 +1290,9 @@ mod tests { #[test] fn test_is_json_projectable_statement() { assert!(is_json_projectable_statement("SELECT 1")); - assert!(is_json_projectable_statement(" WITH t AS (SELECT 1) SELECT * FROM t")); + assert!(is_json_projectable_statement( + " WITH t AS (SELECT 1) SELECT * FROM t" + )); assert!(!is_json_projectable_statement("SHOW TABLES")); assert!(!is_json_projectable_statement("UPDATE t SET a = 1")); } @@ -1228,7 +1326,10 @@ mod tests { normalize_mysql_row_json(&mut row, &high_precision_cols).unwrap(); - assert_eq!(row.get("id").and_then(|v| v.as_str()), Some("9223372036854775807")); + assert_eq!( + row.get("id").and_then(|v| v.as_str()), + Some("9223372036854775807") + ); assert_eq!(row.get("amount").and_then(|v| v.as_str()), 
Some("1234.56")); assert_eq!(row.get("name").and_then(|v| v.as_str()), Some("demo")); assert!(row.get("nullable").unwrap().is_null()); @@ -1247,4 +1348,17 @@ mod tests { assert!(sql.contains("FROM (SELECT * FROM t) AS `__dbpaw_row`")); assert!(!sql.contains(";) AS `__dbpaw_row`")); } + + #[test] + fn test_is_prepared_protocol_unsupported_error() { + assert!(is_prepared_protocol_unsupported_error( + "error returned from database: 1295 (HY000): This command is not supported in the prepared statement protocol yet" + )); + assert!(is_prepared_protocol_unsupported_error( + "prepared statement protocol is unsupported" + )); + assert!(!is_prepared_protocol_unsupported_error( + "syntax error near ...", + )); + } } diff --git a/src-tauri/src/db/drivers/postgres.rs b/src-tauri/src/db/drivers/postgres.rs index b31e055..e1d1bad 100644 --- a/src-tauri/src/db/drivers/postgres.rs +++ b/src-tauri/src/db/drivers/postgres.rs @@ -161,11 +161,7 @@ impl PostgresDriver { .acquire_timeout(std::time::Duration::from_secs(3)) .connect(&dsn) .await - .map_err(|e| { - format!( - "[CONN_FAILED] {e} (hint: check if username/password contain special characters; they must be URL-encoded)" - ) - })?; + .map_err(|e| super::conn_failed_error(&e))?; Ok(Self { pool, @@ -259,10 +255,7 @@ fn is_high_precision_pg_type(data_type: &str, udt_name: &str) -> bool { matches!( data_type.as_str(), "bigint" | "numeric" | "decimal" | "money" - ) || matches!( - udt_name.as_str(), - "int8" | "numeric" | "decimal" | "money" - ) + ) || matches!(udt_name.as_str(), "int8" | "numeric" | "decimal" | "money") } fn normalize_postgres_row_json( @@ -473,9 +466,7 @@ fn first_sql_keyword(sql: &str) -> Option { return None; } let mut end = start; - while end < bytes.len() - && (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_') - { + while end < bytes.len() && (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_') { end += 1; } if end == start { @@ -1208,6 +1199,34 @@ mod tests { ); } + #[test] + fn 
test_conn_string_encodes_credentials_when_ssh_rewrites_target_host() { + let mut form = ConnectionForm { + driver: "postgres".to_string(), + host: Some("db.internal".to_string()), + port: Some(5432), + username: Some("user@name".to_string()), + password: Some("p#ss*@)".to_string()), + database: Some("mydb".to_string()), + ssh_enabled: Some(true), + ssh_host: Some("bastion.internal".to_string()), + ssh_port: Some(22), + ssh_username: Some("jump".to_string()), + ssh_password: Some("ssh#pass".to_string()), + ..Default::default() + }; + + // Match the production flow after the SSH tunnel assigns a local endpoint. + form.host = Some("127.0.0.1".to_string()); + form.port = Some(55432); + + let dsn = build_dsn(&form).unwrap(); + assert_eq!( + dsn, + "postgres://user%40name:p%23ss%2A%40%29@127.0.0.1:55432/mydb" + ); + } + #[test] fn test_conn_string_missing_fields() { let form = ConnectionForm { @@ -1306,10 +1325,7 @@ mod tests { let sql = "CREATE TYPE mood_enum AS ENUM ('sad', 'ok'); CREATE TYPE address_type AS (street VARCHAR(100));"; let statements = split_sql_statements(sql); assert_eq!(statements.len(), 2); - assert_eq!( - statements[0], - "CREATE TYPE mood_enum AS ENUM ('sad', 'ok')" - ); + assert_eq!(statements[0], "CREATE TYPE mood_enum AS ENUM ('sad', 'ok')"); assert_eq!( statements[1], "CREATE TYPE address_type AS (street VARCHAR(100))" @@ -1367,7 +1383,10 @@ CREATE TABLE pg_data_type_test ( row.get("col_bigint").and_then(|v| v.as_str()), Some("9007199254740993") ); - assert_eq!(row.get("col_numeric").and_then(|v| v.as_str()), Some("1234.56")); + assert_eq!( + row.get("col_numeric").and_then(|v| v.as_str()), + Some("1234.56") + ); assert_eq!(row.get("col_text").and_then(|v| v.as_str()), Some("hello")); assert!(row.get("col_null").unwrap().is_null()); } @@ -1382,7 +1401,9 @@ CREATE TABLE pg_data_type_test ( #[test] fn test_is_json_projectable_statement() { assert!(is_json_projectable_statement("SELECT 1")); - assert!(is_json_projectable_statement(" -- a\nWITH t AS 
(SELECT 1) SELECT * FROM t")); + assert!(is_json_projectable_statement( + " -- a\nWITH t AS (SELECT 1) SELECT * FROM t" + )); assert!(is_json_projectable_statement("VALUES (1), (2)")); assert!(is_json_projectable_statement("TABLE my_table")); assert!(!is_json_projectable_statement("INSERT INTO t VALUES (1)")); diff --git a/src-tauri/src/db/drivers/sqlite.rs b/src-tauri/src/db/drivers/sqlite.rs index 04702fd..475fd09 100644 --- a/src-tauri/src/db/drivers/sqlite.rs +++ b/src-tauri/src/db/drivers/sqlite.rs @@ -146,9 +146,12 @@ fn sqlite_cell_to_json( let temporal_kind = sqlite_temporal_decl_kind(declared_type); let declared_bool = sqlite_declared_bool(declared_type); - let raw = row - .try_get_raw(column_name) - .map_err(|e| format!("[QUERY_ERROR] Failed to read SQLite column '{}': {}", column_name, e))?; + let raw = row.try_get_raw(column_name).map_err(|e| { + format!( + "[QUERY_ERROR] Failed to read SQLite column '{}': {}", + column_name, e + ) + })?; if raw.is_null() { return Ok(serde_json::Value::Null); } @@ -170,13 +173,11 @@ fn sqlite_cell_to_json( let maybe_date = if (-200_000..=200_000).contains(&v) { sqlite_format_date_from_days(v) } else { - sqlite_format_datetime_from_unix_seconds_f64(v as f64).and_then( - |s| { - NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S%.f") - .ok() - .map(|dt| dt.date().format("%F").to_string()) - }, - ) + sqlite_format_datetime_from_unix_seconds_f64(v as f64).and_then(|s| { + NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S%.f") + .ok() + .map(|dt| dt.date().format("%F").to_string()) + }) }; maybe_date .map(serde_json::Value::String) @@ -642,10 +643,7 @@ impl DatabaseDriver for SqliteDriver { .map(|s| s.as_str()) .or(Some(col.type_info().name())); let value = sqlite_cell_to_json(row, name, declared_type)?; - obj.insert( - name.to_string(), - value, - ); + obj.insert(name.to_string(), value); } data.push(serde_json::Value::Object(obj)); } @@ -718,10 +716,7 @@ impl DatabaseDriver for SqliteDriver { for col in row.columns() 
{ let name = col.name(); let value = sqlite_cell_to_json(row, name, Some(col.type_info().name()))?; - obj.insert( - name.to_string(), - value, - ); + obj.insert(name.to_string(), value); } data.push(serde_json::Value::Object(obj)); } diff --git a/src-tauri/src/db/local.rs b/src-tauri/src/db/local.rs index ae6a390..5d8b7fc 100644 --- a/src-tauri/src/db/local.rs +++ b/src-tauri/src/db/local.rs @@ -24,8 +24,12 @@ impl LocalDb { .path() .app_data_dir() .map_err(|e| e.to_string())?; + Self::init_with_app_dir(&app_dir).await + } + + pub async fn init_with_app_dir(app_dir: &Path) -> Result { if !app_dir.exists() { - fs::create_dir_all(&app_dir).map_err(|e| e.to_string())?; + fs::create_dir_all(app_dir).map_err(|e| e.to_string())?; } let ai_master_key = Self::load_or_create_ai_master_key(&app_dir)?; let db_path = app_dir.join("dbpaw.sqlite"); @@ -903,6 +907,7 @@ impl LocalDb { mod tests { use super::LocalDb; use crate::models::{AiProviderForm, ConnectionForm}; + use rand::RngCore; use sqlx::sqlite::SqlitePoolOptions; async fn make_test_db() -> LocalDb { @@ -931,10 +936,10 @@ mod tests { .expect("apply migration"); } - LocalDb { - pool, - ai_master_key: [7u8; 32], - } + let mut ai_master_key = [0u8; 32]; + rand::rngs::OsRng.fill_bytes(&mut ai_master_key); + + LocalDb { pool, ai_master_key } } fn provider_form( @@ -958,7 +963,8 @@ mod tests { #[test] fn api_key_encrypt_decrypt_round_trip_and_format_validation() { - let key = [3u8; 32]; + let mut key = [0u8; 32]; + rand::rngs::OsRng.fill_bytes(&mut key); let encrypted = LocalDb::encrypt_ai_api_key_raw(&key, "secret-123").unwrap(); assert!(LocalDb::has_encrypted_ai_api_key(&encrypted)); let decrypted = LocalDb::decrypt_ai_api_key_raw(&key, &encrypted).unwrap(); diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 587e950..d4fb21f 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -142,6 +142,7 @@ pub fn run() { commands::ai::ai_delete_conversation, commands::transfer::export_table_data, 
commands::transfer::export_query_result, + commands::transfer::import_sql_file, ]) .build(tauri::generate_context!()) .expect("error while building tauri application"); diff --git a/src-tauri/src/models/mod.rs b/src-tauri/src/models/mod.rs index a05100d..a6af1dc 100644 --- a/src-tauri/src/models/mod.rs +++ b/src-tauri/src/models/mod.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use std::fmt; #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] #[serde(rename_all = "camelCase")] @@ -218,7 +219,7 @@ pub struct TableDataResponse { pub execution_time_ms: i64, } -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Clone, Serialize, Deserialize, Default)] #[serde(rename_all = "camelCase")] pub struct ConnectionForm { pub driver: String, // "postgres" | "mysql" | "tidb" | "mariadb" | "sqlite" | "duckdb" | "clickhouse" | "mssql" @@ -241,6 +242,36 @@ pub struct ConnectionForm { pub ssh_key_path: Option, } +impl fmt::Debug for ConnectionForm { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let username = self.username.as_ref().map(|_| ""); + let password = self.password.as_ref().map(|_| ""); + let ssl_ca_cert = self.ssl_ca_cert.as_ref().map(|_| ""); + let ssh_username = self.ssh_username.as_ref().map(|_| ""); + let ssh_password = self.ssh_password.as_ref().map(|_| ""); + f.debug_struct("ConnectionForm") + .field("driver", &self.driver) + .field("name", &self.name) + .field("host", &self.host) + .field("port", &self.port) + .field("database", &self.database) + .field("schema", &self.schema) + .field("username", &username) + .field("password", &password) + .field("ssl", &self.ssl) + .field("ssl_mode", &self.ssl_mode) + .field("ssl_ca_cert", &ssl_ca_cert) + .field("file_path", &self.file_path) + .field("ssh_enabled", &self.ssh_enabled) + .field("ssh_host", &self.ssh_host) + .field("ssh_port", &self.ssh_port) + .field("ssh_username", &ssh_username) + .field("ssh_password", &ssh_password) + .field("ssh_key_path", 
&self.ssh_key_path) + .finish() + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct TestConnectionResult { @@ -276,3 +307,29 @@ pub struct TableSchema { pub struct SchemaOverview { pub tables: Vec, } + +#[cfg(test)] +mod tests { + use super::ConnectionForm; + + #[test] + fn connection_form_debug_redacts_sensitive_fields() { + let form = ConnectionForm { + driver: "mysql".to_string(), + host: Some("127.0.0.1".to_string()), + username: Some("root".to_string()), + password: Some("secret".to_string()), + ssl_ca_cert: Some("cert-data".to_string()), + ssh_username: Some("jump".to_string()), + ssh_password: Some("jump-secret".to_string()), + ..Default::default() + }; + + let printed = format!("{form:?}"); + assert!(!printed.contains("root")); + assert!(!printed.contains("secret")); + assert!(!printed.contains("cert-data")); + assert!(!printed.contains("jump-secret")); + assert!(printed.contains("")); + } +} diff --git a/src-tauri/src/ssh.rs b/src-tauri/src/ssh.rs index e6b3cef..3c91e17 100644 --- a/src-tauri/src/ssh.rs +++ b/src-tauri/src/ssh.rs @@ -38,13 +38,21 @@ pub fn start_ssh_tunnel(config: &ConnectionForm) -> Result { .clone() .ok_or("SSH Username is required")?; let ssh_password = config.ssh_password.clone(); - let ssh_key_path = config - .ssh_key_path - .clone() - .and_then(|v| if v.trim().is_empty() { None } else { Some(v) }); + let ssh_key_path = + config + .ssh_key_path + .clone() + .and_then(|v| if v.trim().is_empty() { None } else { Some(v) }); let target_host = config.host.clone().unwrap_or("localhost".to_string()); - let target_port = config.port.unwrap_or(5432); + let default_port: i64 = match config.driver.to_ascii_lowercase().as_str() { + "mysql" => 3306, + "mssql" => 1433, + "clickhouse" => 9000, + "sqlite" => 0, + _ => 5432, // postgres and unknown drivers + }; + let target_port = config.port.unwrap_or(default_port); if target_port < 1 || target_port > 65535 { return Err("Target port must be between 1 
and 65535".to_string()); } @@ -244,6 +252,55 @@ mod tests { use super::*; use crate::models::ConnectionForm; + #[test] + fn test_target_port_default_by_driver() { + // Verify driver-specific default ports are applied when port is None. + // We can only test port validation since start_ssh_tunnel requires a real host; + // use an out-of-range port to force early validation failure and confirm the + // default port resolution branch is NOT taken (port=None should NOT produce 5432 for MySQL). + + // For MySQL with no port set, the default must be 3306 (not 5432). + // We verify indirectly: if port is None and driver is mysql, target_port = 3306 which + // passes validation (1..=65535). The tunnel will fail to connect (no real host), but + // the validation itself won't error with "Target port must be between 1 and 65535". + let config_mysql = ConnectionForm { + driver: "mysql".to_string(), + host: Some("127.0.0.1".to_string()), + port: None, // deliberately omitted — should default to 3306 + ssh_host: Some("127.0.0.1".to_string()), + ssh_port: Some(22), + ssh_username: Some("user".to_string()), + ssh_password: Some("pass".to_string()), + ..Default::default() + }; + let result = start_ssh_tunnel(&config_mysql); + // Should fail with a network/connect error, NOT a port validation error + if let Err(e) = result { + assert!( + !e.contains("Target port must be between 1 and 65535"), + "MySQL default port (3306) should pass validation, got: {e}" + ); + } + + let config_mssql = ConnectionForm { + driver: "mssql".to_string(), + host: Some("127.0.0.1".to_string()), + port: None, // should default to 1433 + ssh_host: Some("127.0.0.1".to_string()), + ssh_port: Some(22), + ssh_username: Some("user".to_string()), + ssh_password: Some("pass".to_string()), + ..Default::default() + }; + let result = start_ssh_tunnel(&config_mssql); + if let Err(e) = result { + assert!( + !e.contains("Target port must be between 1 and 65535"), + "MSSQL default port (1433) should pass validation, got: 
{e}" + ); + } + } + #[test] fn test_ssh_port_validation() { let mut config = ConnectionForm::default(); diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json index 61bb45e..9837ee5 100644 --- a/src-tauri/tauri.conf.json +++ b/src-tauri/tauri.conf.json @@ -1,7 +1,7 @@ { "$schema": "https://schema.tauri.app/config/2", "productName": "DbPaw", - "version": "0.2.9", + "version": "0.3.0", "identifier": "com.father.dbpaw", "build": { "beforeDevCommand": "bun run dev", diff --git a/src-tauri/tests/clickhouse_integration.rs b/src-tauri/tests/clickhouse_integration.rs index 2927ed8..08f4221 100644 --- a/src-tauri/tests/clickhouse_integration.rs +++ b/src-tauri/tests/clickhouse_integration.rs @@ -1,33 +1,21 @@ +#[path = "common/clickhouse_context.rs"] +mod clickhouse_context; + use dbpaw_lib::db::drivers::clickhouse::ClickHouseDriver; use dbpaw_lib::db::drivers::DatabaseDriver; -use dbpaw_lib::models::ConnectionForm; -use std::env; +use testcontainers::clients::Cli; #[tokio::test] #[ignore] async fn test_clickhouse_integration_flow() { - let host = env::var("CLICKHOUSE_HOST").unwrap_or_else(|_| "localhost".to_string()); - let port = env::var("CLICKHOUSE_PORT") - .unwrap_or_else(|_| "8123".to_string()) - .parse() - .unwrap(); - let username = env::var("CLICKHOUSE_USER").unwrap_or_else(|_| "default".to_string()); - let password = env::var("CLICKHOUSE_PASSWORD").unwrap_or_default(); - let database = env::var("CLICKHOUSE_DB").unwrap_or_else(|_| "default".to_string()); - - let form = ConnectionForm { - driver: "clickhouse".to_string(), - host: Some(host), - port: Some(port), - username: Some(username), - password: Some(password), - database: Some(database.clone()), - ..Default::default() - }; - - let driver = ClickHouseDriver::connect(&form) - .await - .expect("Failed to connect to ClickHouse"); + let docker = (!clickhouse_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = 
clickhouse_context::clickhouse_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("CLICKHOUSE_DB or container default database should be present"); + let driver: ClickHouseDriver = + clickhouse_context::connect_with_retry(|| ClickHouseDriver::connect(&form)).await; driver .test_connection() @@ -39,71 +27,20 @@ async fn test_clickhouse_integration_flow() { .await .expect("list_databases failed"); assert!(!databases.is_empty(), "list_databases returned empty"); - - let tables = driver - .list_tables(Some(database.clone())) - .await - .expect("list_tables failed"); - - if let Some(first_table) = tables.first() { - let _metadata = driver - .get_table_metadata(first_table.schema.clone(), first_table.name.clone()) - .await - .expect("get_table_metadata failed"); - - let _ddl = driver - .get_table_ddl(first_table.schema.clone(), first_table.name.clone()) - .await - .expect("get_table_ddl failed"); - } - - let query_result = driver - .execute_query("SELECT 1 AS ok".to_string()) - .await - .expect("execute_query failed"); - assert_eq!(query_result.row_count, 1); - - let overview = driver - .get_schema_overview(Some(database)) - .await - .expect("get_schema_overview failed"); assert!( - !overview.tables.is_empty() || tables.is_empty(), - "schema overview expected to have tables when list_tables has entries" + databases.iter().any(|db| db == &database), + "list_databases should include {}", + database ); - driver.close().await; -} - -#[tokio::test] -#[ignore] -async fn test_clickhouse_type_mapping_and_metadata_flow() { - let host = env::var("CLICKHOUSE_HOST").unwrap_or_else(|_| "localhost".to_string()); - let port = env::var("CLICKHOUSE_PORT") - .unwrap_or_else(|_| "8123".to_string()) - .parse() - .unwrap(); - let username = env::var("CLICKHOUSE_USER").unwrap_or_else(|_| "default".to_string()); - let password = env::var("CLICKHOUSE_PASSWORD").unwrap_or_default(); - let database = env::var("CLICKHOUSE_DB").unwrap_or_else(|_| 
"default".to_string()); - - let form = ConnectionForm { - driver: "clickhouse".to_string(), - host: Some(host), - port: Some(port), - username: Some(username), - password: Some(password), - database: Some(database.clone()), - ..Default::default() - }; - - let driver = ClickHouseDriver::connect(&form) - .await - .expect("Failed to connect to ClickHouse"); - - let table_name = "dbpaw_ch_type_probe"; + let table_name = "dbpaw_clickhouse_type_probe"; + let view_name = "dbpaw_clickhouse_type_probe_v"; let qualified = format!("`{}`.`{}`", database, table_name); + let qualified_view = format!("`{}`.`{}`", database, view_name); + let _ = driver + .execute_query(format!("DROP VIEW IF EXISTS {}", qualified_view)) + .await; let _ = driver .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) .await; @@ -112,9 +49,10 @@ async fn test_clickhouse_type_mapping_and_metadata_flow() { .execute_query(format!( "CREATE TABLE {} (\ id UInt32, \ + name String, \ amount Decimal(10,2), \ - created_at DateTime, \ - note Nullable(String)\ + payload String, \ + created_at DateTime\ ) ENGINE = MergeTree ORDER BY id", qualified )) @@ -123,41 +61,61 @@ async fn test_clickhouse_type_mapping_and_metadata_flow() { driver .execute_query(format!( - "INSERT INTO {} (id, amount, created_at, note) VALUES \ - (1, 12.34, toDateTime('2026-01-02 03:04:05'), NULL)", - qualified + "CREATE VIEW {} AS SELECT id, name FROM {}", + qualified_view, qualified )) .await - .expect("insert probe row failed"); + .expect("create view failed"); - // 1) list_databases/list_tables - let databases = driver - .list_databases() + driver + .execute_query(format!( + "INSERT INTO {} (id, name, amount, payload, created_at) VALUES \ + (1, 'hello', 12.34, 'DEADBEEF', toDateTime('2026-01-02 03:04:05'))", + qualified + )) .await - .expect("list_databases failed"); - assert!( - databases.iter().any(|d| d == &database), - "list_databases should include active database {}", - database - ); + .expect("insert failed"); let tables = 
driver .list_tables(Some(database.clone())) .await .expect("list_tables failed"); assert!( - tables - .iter() - .any(|t| t.schema == database && t.name == table_name), - "list_tables should include {}.{}", - database, + tables.iter().any(|t| t.name == table_name), + "list_tables should include {}", table_name ); + assert!( + tables.iter().any(|t| t.name == view_name), + "list_tables should include {}", + view_name + ); + + let metadata = driver + .get_table_metadata(database.clone(), table_name.to_string()) + .await + .expect("get_table_metadata failed"); + assert!( + metadata.columns.iter().any(|c| c.name == "id" && c.primary_key), + "metadata should include primary key id" + ); + assert!( + metadata.columns.iter().any(|c| c.name == "payload"), + "metadata should include payload column" + ); + + let ddl = driver + .get_table_ddl(database.clone(), table_name.to_string()) + .await + .expect("get_table_ddl failed"); + assert!( + ddl.to_uppercase().contains("CREATE TABLE"), + "DDL should contain CREATE TABLE" + ); - // 2) execute_query type mapping (Decimal/DateTime/Nullable) let result = driver .execute_query(format!( - "SELECT amount, created_at, note FROM {} WHERE id = 1", + "SELECT id, name, amount, payload, created_at FROM {} WHERE id = 1", qualified )) .await @@ -167,19 +125,211 @@ async fn test_clickhouse_type_mapping_and_metadata_flow() { .data .first() .expect("typed result should include at least one row"); - + let id_value = row.get("id").expect("id should exist"); + assert!( + id_value == &serde_json::Value::String("1".to_string()) + || id_value == &serde_json::Value::Number(serde_json::Number::from(1)), + "unexpected id value: {:?}", + id_value + ); + assert_eq!(row["name"], serde_json::Value::String("hello".to_string())); assert!(row.get("amount").is_some(), "amount should exist"); + assert!(row.get("payload").is_some(), "payload should exist"); + + let _ = driver + .execute_query(format!("DROP VIEW IF EXISTS {}", qualified_view)) + .await; + let _ = 
driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_clickhouse_get_table_data_supports_pagination_sort_filter_and_order_by() { + let docker = (!clickhouse_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = clickhouse_context::clickhouse_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("CLICKHOUSE_DB or container default database should be present"); + let driver: ClickHouseDriver = + clickhouse_context::connect_with_retry(|| ClickHouseDriver::connect(&form)).await; + + let table_name = "dbpaw_clickhouse_grid_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + + driver + .execute_query(format!( + "CREATE TABLE {} (id UInt32, name String, score Int32) \ + ENGINE = MergeTree ORDER BY id", + qualified + )) + .await + .expect("create dbpaw_clickhouse_grid_probe failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name, score) VALUES \ + (1, 'alpha', 10), (2, 'beta', 20), (3, 'gamma', 30), (4, 'delta', 40)", + qualified + )) + .await + .expect("insert dbpaw_clickhouse_grid_probe failed"); + + let page1 = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 2, + Some("score".to_string()), + Some("desc".to_string()), + None, + None, + ) + .await + .expect("get_table_data page1 failed"); + assert_eq!(page1.total, 4); + assert_eq!(page1.data.len(), 2); + assert_eq!( + page1.data[0]["name"], + serde_json::Value::String("delta".to_string()) + ); + + let filtered = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 10, + None, + None, + Some("score >= 20".to_string()), + None, + ) + .await + .expect("get_table_data with filter failed"); + assert_eq!(filtered.total, 3); + + let ordered = driver + .get_table_data( + 
database.clone(), + table_name.to_string(), + 1, + 1, + Some("id".to_string()), + Some("asc".to_string()), + None, + Some("name DESC".to_string()), + ) + .await + .expect("get_table_data with order_by failed"); + assert_eq!(ordered.total, 4); + assert_eq!(ordered.data.len(), 1); + assert_eq!( + ordered.data[0]["name"], + serde_json::Value::String("gamma".to_string()) + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_clickhouse_get_table_data_rejects_invalid_sort_column() { + let docker = (!clickhouse_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = clickhouse_context::clickhouse_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("CLICKHOUSE_DB or container default database should be present"); + let driver: ClickHouseDriver = + clickhouse_context::connect_with_retry(|| ClickHouseDriver::connect(&form)).await; + + let table_name = "dbpaw_clickhouse_invalid_sort_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id UInt32) ENGINE = MergeTree ORDER BY id", + qualified + )) + .await + .expect("create dbpaw_clickhouse_invalid_sort_probe failed"); + + let result = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 10, + Some("id desc".to_string()), + Some("desc".to_string()), + None, + None, + ) + .await; + let err = result.expect_err("invalid sort column should return error"); assert!( - row["amount"].is_number() || row["amount"].is_string(), - "Decimal should be represented as number or string in JSON" + err.contains("[VALIDATION_ERROR] Invalid sort column name"), + "unexpected error: {}", + err ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + 
.await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_clickhouse_table_structure_and_schema_overview() { + let docker = (!clickhouse_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = clickhouse_context::clickhouse_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("CLICKHOUSE_DB or container default database should be present"); + let driver: ClickHouseDriver = + clickhouse_context::connect_with_retry(|| ClickHouseDriver::connect(&form)).await; + + let table_name = "dbpaw_clickhouse_overview_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + + driver + .execute_query(format!( + "CREATE TABLE {} (id UInt32, label String) ENGINE = MergeTree ORDER BY id", + qualified + )) + .await + .expect("create dbpaw_clickhouse_overview_probe failed"); + + let structure = driver + .get_table_structure(database.clone(), table_name.to_string()) + .await + .expect("get_table_structure failed"); assert!( - row["created_at"].is_string(), - "DateTime should be represented as string in JSON" + structure.columns.iter().any(|c| c.name == "id"), + "table structure should include id" + ); + assert!( + structure.columns.iter().any(|c| c.name == "label"), + "table structure should include label" ); - assert!(row["note"].is_null(), "Nullable(String) should decode NULL"); - // 3) schema overview + DDL let overview = driver .get_schema_overview(Some(database.clone())) .await @@ -194,14 +344,152 @@ async fn test_clickhouse_type_mapping_and_metadata_flow() { table_name ); - let ddl = driver - .get_table_ddl(database.clone(), table_name.to_string()) + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_clickhouse_metadata_includes_engine_extra() { + let docker = 
(!clickhouse_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = clickhouse_context::clickhouse_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("CLICKHOUSE_DB or container default database should be present"); + let driver: ClickHouseDriver = + clickhouse_context::connect_with_retry(|| ClickHouseDriver::connect(&form)).await; + + let table_name = "dbpaw_clickhouse_meta_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id UInt32, name String) ENGINE = MergeTree ORDER BY id", + qualified + )) .await - .expect("get_table_ddl failed"); + .expect("create dbpaw_clickhouse_meta_probe failed"); + + let metadata = driver + .get_table_metadata(database.clone(), table_name.to_string()) + .await + .expect("get_table_metadata failed"); + assert!( + metadata.columns.iter().any(|c| c.name == "id"), + "metadata should include id column" + ); + assert!( + metadata.clickhouse_extra.is_some(), + "metadata should include clickhouse engine extra" + ); + assert!( + metadata + .clickhouse_extra + .as_ref() + .map(|extra| extra.engine.contains("MergeTree")) + .unwrap_or(false), + "clickhouse extra engine should include MergeTree" + ); + assert!( + metadata.indexes.is_empty(), + "clickhouse metadata indexes should be empty for now" + ); + assert!( + metadata.foreign_keys.is_empty(), + "clickhouse metadata foreign_keys should be empty for now" + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_clickhouse_boolean_and_json_type_mapping_regression() { + let docker = (!clickhouse_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = clickhouse_context::clickhouse_form_from_test_context(docker.as_ref()); + let 
database = form + .database + .clone() + .expect("CLICKHOUSE_DB or container default database should be present"); + let driver: ClickHouseDriver = + clickhouse_context::connect_with_retry(|| ClickHouseDriver::connect(&form)).await; + + let table_name = "dbpaw_clickhouse_bool_json_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id UInt32, flag UInt8, meta String) \ + ENGINE = MergeTree ORDER BY id", + qualified + )) + .await + .expect("create bool/json probe table failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, flag, meta) VALUES (1, 1, '{{\"tier\":\"gold\"}}')", + qualified + )) + .await + .expect("insert bool/json probe row failed"); + + let query_result = driver + .execute_query(format!( + "SELECT flag, JSONExtractString(meta, 'tier') AS tier FROM {} WHERE id = 1", + qualified + )) + .await + .expect("select bool/json row failed"); + assert_eq!(query_result.row_count, 1); + let query_row = query_result.data.first().expect("query row should exist"); + let query_flag = query_row + .get("flag") + .expect("flag should exist in query result"); + assert!( + query_flag == &serde_json::Value::Bool(true) + || query_flag == &serde_json::Value::Number(serde_json::Number::from(1)) + || query_flag == &serde_json::Value::String("1".to_string()), + "unexpected query flag value: {:?}", + query_flag + ); + assert_eq!(query_row["tier"], serde_json::Value::String("gold".to_string())); + + let table_data = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 10, + None, + None, + None, + None, + ) + .await + .expect("get_table_data for bool/json probe failed"); + assert_eq!(table_data.total, 1); + let grid_row = table_data.data.first().expect("table row should exist"); + let grid_flag = grid_row + .get("flag") + .expect("flag should exist in table_data result"); assert!( 
- ddl.contains(table_name) && ddl.to_uppercase().contains("CREATE TABLE"), - "DDL should contain CREATE TABLE and table name" + grid_flag == &serde_json::Value::Bool(true) + || grid_flag == &serde_json::Value::Number(serde_json::Number::from(1)) + || grid_flag == &serde_json::Value::String("1".to_string()), + "unexpected grid flag value: {:?}", + grid_flag ); + assert!(grid_row.get("meta").is_some(), "meta should exist in table_data"); let _ = driver .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) diff --git a/src-tauri/tests/common/clickhouse_context.rs b/src-tauri/tests/common/clickhouse_context.rs new file mode 100644 index 0000000..8e540c2 --- /dev/null +++ b/src-tauri/tests/common/clickhouse_context.rs @@ -0,0 +1,80 @@ +mod shared; + +use dbpaw_lib::models::ConnectionForm; +use std::env; +use std::time::Duration; +use testcontainers::clients::Cli; +use testcontainers::core::WaitFor; +use testcontainers::{Container, GenericImage, RunnableImage}; + +pub use shared::{connect_with_retry, should_reuse_local_db}; + +pub fn clickhouse_form_from_test_context<'a>( + docker: Option<&'a Cli>, +) -> (Option>, ConnectionForm) { + if should_reuse_local_db() { + return (None, clickhouse_form_from_local_env()); + } + shared::ensure_docker_available(); + + let docker = docker.expect("docker client is required when IT_REUSE_LOCAL_DB is not enabled"); + let image = GenericImage::new("clickhouse/clickhouse-server", "24.3") + .with_env_var("CLICKHOUSE_USER", "dbpaw") + .with_env_var("CLICKHOUSE_PASSWORD", "123456") + .with_env_var("CLICKHOUSE_DB", "test_db") + .with_wait_for(WaitFor::seconds(8)) + .with_exposed_port(8123); + let runnable = + RunnableImage::from(image).with_container_name(shared::unique_container_name("clickhouse")); + let container = docker.run(runnable); + let port = container.get_host_port_ipv4(8123); + + shared::wait_for_port("127.0.0.1", port, Duration::from_secs(60)); + + let mut form = ConnectionForm { + driver: "clickhouse".to_string(), + 
host: Some("127.0.0.1".to_string()), + port: Some(i64::from(port)), + username: Some("dbpaw".to_string()), + password: Some("123456".to_string()), + database: Some("test_db".to_string()), + ..Default::default() + }; + apply_clickhouse_env_overrides(&mut form); + (Some(container), form) +} + +fn clickhouse_form_from_local_env() -> ConnectionForm { + let mut form = ConnectionForm { + driver: "clickhouse".to_string(), + host: Some(shared::env_or("CLICKHOUSE_HOST", "localhost")), + port: Some(shared::env_i64("CLICKHOUSE_PORT", 8123)), + username: Some(shared::env_or("CLICKHOUSE_USER", "default")), + password: Some(shared::env_or("CLICKHOUSE_PASSWORD", "")), + database: Some(shared::env_or("CLICKHOUSE_DB", "default")), + ..Default::default() + }; + apply_clickhouse_env_overrides(&mut form); + form +} + +fn apply_clickhouse_env_overrides(form: &mut ConnectionForm) { + if let Ok(host) = env::var("CLICKHOUSE_HOST") { + form.host = Some(host); + } + if let Ok(port) = env::var("CLICKHOUSE_PORT") { + form.port = Some( + port.parse::() + .expect("CLICKHOUSE_PORT should be a valid number"), + ); + } + if let Ok(user) = env::var("CLICKHOUSE_USER") { + form.username = Some(user); + } + if let Ok(password) = env::var("CLICKHOUSE_PASSWORD") { + form.password = Some(password); + } + if let Ok(database) = env::var("CLICKHOUSE_DB") { + form.database = Some(database); + } +} diff --git a/src-tauri/tests/common/duckdb_context.rs b/src-tauri/tests/common/duckdb_context.rs new file mode 100644 index 0000000..1bebb18 --- /dev/null +++ b/src-tauri/tests/common/duckdb_context.rs @@ -0,0 +1,40 @@ +use dbpaw_lib::models::ConnectionForm; +use std::env; +use std::path::PathBuf; +use uuid::Uuid; + +pub fn should_reuse_local_db() -> bool { + env::var("IT_REUSE_LOCAL_DB") + .map(|value| value == "1" || value.eq_ignore_ascii_case("true")) + .unwrap_or(false) +} + +pub fn duckdb_form_from_test_context() -> (PathBuf, ConnectionForm) { + let path = duckdb_test_path_from_context(); + let form = 
ConnectionForm { + driver: "duckdb".to_string(), + file_path: Some(path.to_string_lossy().to_string()), + ..Default::default() + }; + (path, form) +} + +fn duckdb_test_path_from_context() -> PathBuf { + if let Ok(v) = env::var("DUCKDB_IT_DB_PATH") { + return PathBuf::from(v); + } + if let Ok(v) = env::var("DUCKDB_DB_PATH") { + return PathBuf::from(v); + } + if should_reuse_local_db() { + let mut p = env::temp_dir(); + p.push("dbpaw-duckdb-local-it.duckdb"); + return p; + } + let mut p = env::temp_dir(); + p.push(format!( + "dbpaw-duckdb-integration-{}.duckdb", + Uuid::new_v4() + )); + p +} diff --git a/src-tauri/tests/common/mariadb_context.rs b/src-tauri/tests/common/mariadb_context.rs new file mode 100644 index 0000000..d81e5b2 --- /dev/null +++ b/src-tauri/tests/common/mariadb_context.rs @@ -0,0 +1,79 @@ +mod shared; + +use dbpaw_lib::models::ConnectionForm; +use std::env; +use std::time::Duration; +use testcontainers::clients::Cli; +use testcontainers::core::WaitFor; +use testcontainers::{Container, GenericImage, RunnableImage}; + +pub use shared::{connect_with_retry, should_reuse_local_db}; + +pub fn mariadb_form_from_test_context<'a>( + docker: Option<&'a Cli>, +) -> (Option>, ConnectionForm) { + if should_reuse_local_db() { + return (None, mariadb_form_from_local_env()); + } + shared::ensure_docker_available(); + + let docker = docker.expect("docker client is required when IT_REUSE_LOCAL_DB is not enabled"); + let image = GenericImage::new("mariadb", "11") + .with_env_var("MARIADB_ROOT_PASSWORD", "123456") + .with_env_var("MARIADB_DATABASE", "test_db") + .with_wait_for(WaitFor::seconds(5)) + .with_exposed_port(3306); + let runnable = + RunnableImage::from(image).with_container_name(shared::unique_container_name("mariadb")); + let container = docker.run(runnable); + let port = container.get_host_port_ipv4(3306); + + shared::wait_for_port("127.0.0.1", port, Duration::from_secs(45)); + + let mut form = ConnectionForm { + driver: "mariadb".to_string(), + host: 
Some("127.0.0.1".to_string()), + port: Some(i64::from(port)), + username: Some("root".to_string()), + password: Some("123456".to_string()), + database: Some("test_db".to_string()), + ..Default::default() + }; + apply_mariadb_env_overrides(&mut form); + (Some(container), form) +} + +fn mariadb_form_from_local_env() -> ConnectionForm { + let mut form = ConnectionForm { + driver: "mariadb".to_string(), + host: Some(shared::env_or("MARIADB_HOST", "localhost")), + port: Some(shared::env_i64("MARIADB_PORT", 3306)), + username: Some(shared::env_or("MARIADB_USER", "root")), + password: Some(shared::env_or("MARIADB_PASSWORD", "123456")), + database: Some(shared::env_or("MARIADB_DB", "test_db")), + ..Default::default() + }; + apply_mariadb_env_overrides(&mut form); + form +} + +fn apply_mariadb_env_overrides(form: &mut ConnectionForm) { + if let Ok(host) = env::var("MARIADB_HOST") { + form.host = Some(host); + } + if let Ok(port) = env::var("MARIADB_PORT") { + form.port = Some( + port.parse::() + .expect("MARIADB_PORT should be a valid number"), + ); + } + if let Ok(user) = env::var("MARIADB_USER") { + form.username = Some(user); + } + if let Ok(password) = env::var("MARIADB_PASSWORD") { + form.password = Some(password); + } + if let Ok(database) = env::var("MARIADB_DB") { + form.database = Some(database); + } +} diff --git a/src-tauri/tests/common/mssql_context.rs b/src-tauri/tests/common/mssql_context.rs new file mode 100644 index 0000000..79eab54 --- /dev/null +++ b/src-tauri/tests/common/mssql_context.rs @@ -0,0 +1,118 @@ +mod shared; + +use dbpaw_lib::models::ConnectionForm; +use std::env; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Mutex; +use std::sync::OnceLock; +use std::time::Duration; +use testcontainers::clients::Cli; +use testcontainers::core::WaitFor; +use testcontainers::{Container, GenericImage, RunnableImage}; + +pub use shared::{connect_with_retry, should_reuse_local_db}; + +struct SharedMssql { + container: Mutex>>, + form: 
ConnectionForm, + ref_count: AtomicUsize, +} + +pub struct MssqlContainerGuard { + shared: &'static SharedMssql, +} + +impl Drop for MssqlContainerGuard { + fn drop(&mut self) { + if self.shared.ref_count.fetch_sub(1, Ordering::AcqRel) == 1 { + let mut container = self + .shared + .container + .lock() + .expect("mssql container mutex poisoned"); + let _ = container.take(); + } + } +} + +static MSSQL_SHARED: OnceLock = OnceLock::new(); + +pub fn mssql_form_from_test_context( + docker: Option<&Cli>, +) -> (Option, ConnectionForm) { + let _ = docker; + if should_reuse_local_db() { + return (None, mssql_form_from_local_env()); + } + shared::ensure_docker_available(); + + let shared = MSSQL_SHARED.get_or_init(|| { + let docker = Box::leak(Box::new(Cli::default())); + let image = GenericImage::new("mcr.microsoft.com/mssql/server", "2022-latest") + .with_env_var("ACCEPT_EULA", "Y") + .with_env_var("MSSQL_PID", "Developer") + .with_env_var("MSSQL_SA_PASSWORD", "YourStrong!Passw0rd") + .with_wait_for(WaitFor::seconds(20)) + .with_exposed_port(1433); + let runnable = + RunnableImage::from(image).with_container_name(shared::unique_container_name("mssql")); + let container = docker.run(runnable); + let port = container.get_host_port_ipv4(1433); + + shared::wait_for_port("127.0.0.1", port, Duration::from_secs(90)); + + let mut form = ConnectionForm { + driver: "mssql".to_string(), + host: Some("127.0.0.1".to_string()), + port: Some(i64::from(port)), + username: Some("sa".to_string()), + password: Some("YourStrong!Passw0rd".to_string()), + database: Some("master".to_string()), + ..Default::default() + }; + apply_mssql_env_overrides(&mut form); + SharedMssql { + container: Mutex::new(Some(container)), + form, + ref_count: AtomicUsize::new(0), + } + }); + shared.ref_count.fetch_add(1, Ordering::AcqRel); + + (Some(MssqlContainerGuard { shared }), shared.form.clone()) +} + +fn mssql_form_from_local_env() -> ConnectionForm { + let mut form = ConnectionForm { + driver: 
"mssql".to_string(), + host: Some(shared::env_or("MSSQL_HOST", "localhost")), + port: Some(shared::env_i64("MSSQL_PORT", 1433)), + username: Some(shared::env_or("MSSQL_USER", "sa")), + password: Some(shared::env_or("MSSQL_PASSWORD", "")), + database: Some(shared::env_or("MSSQL_DB", "master")), + ..Default::default() + }; + apply_mssql_env_overrides(&mut form); + form +} + +fn apply_mssql_env_overrides(form: &mut ConnectionForm) { + if let Ok(host) = env::var("MSSQL_HOST") { + form.host = Some(host); + } + if let Ok(port) = env::var("MSSQL_PORT") { + form.port = Some( + port.parse::() + .expect("MSSQL_PORT should be a valid number"), + ); + } + if let Ok(user) = env::var("MSSQL_USER") { + form.username = Some(user); + } + if let Ok(password) = env::var("MSSQL_PASSWORD") { + form.password = Some(password); + } + if let Ok(database) = env::var("MSSQL_DB") { + form.database = Some(database); + } +} diff --git a/src-tauri/tests/common/mysql_context.rs b/src-tauri/tests/common/mysql_context.rs new file mode 100644 index 0000000..2707f70 --- /dev/null +++ b/src-tauri/tests/common/mysql_context.rs @@ -0,0 +1,80 @@ +mod shared; + +use dbpaw_lib::models::ConnectionForm; +use std::env; +use std::time::Duration; +use testcontainers::clients::Cli; +use testcontainers::core::WaitFor; +use testcontainers::{Container, GenericImage, RunnableImage}; + +pub use shared::{connect_with_retry, should_reuse_local_db}; + +pub fn mysql_form_from_test_context<'a>( + docker: Option<&'a Cli>, +) -> (Option>, ConnectionForm) { + if should_reuse_local_db() { + return (None, mysql_form_from_local_env()); + } + shared::ensure_docker_available(); + + let docker = docker.expect("docker client is required when IT_REUSE_LOCAL_DB is not enabled"); + let image = GenericImage::new("mysql", "8.0") + .with_env_var("MYSQL_ROOT_PASSWORD", "123456") + .with_env_var("MYSQL_ROOT_HOST", "%") + .with_env_var("MYSQL_DATABASE", "test_db") + .with_wait_for(WaitFor::seconds(5)) + .with_exposed_port(3306); + let 
runnable = + RunnableImage::from(image).with_container_name(shared::unique_container_name("mysql")); + let container = docker.run(runnable); + let port = container.get_host_port_ipv4(3306); + + shared::wait_for_port("127.0.0.1", port, Duration::from_secs(45)); + + let mut form = ConnectionForm { + driver: "mysql".to_string(), + host: Some("127.0.0.1".to_string()), + port: Some(i64::from(port)), + username: Some("root".to_string()), + password: Some("123456".to_string()), + database: Some("test_db".to_string()), + ..Default::default() + }; + apply_mysql_env_overrides(&mut form); + (Some(container), form) +} + +fn mysql_form_from_local_env() -> ConnectionForm { + let mut form = ConnectionForm { + driver: "mysql".to_string(), + host: Some(shared::env_or("MYSQL_HOST", "localhost")), + port: Some(shared::env_i64("MYSQL_PORT", 3306)), + username: Some(shared::env_or("MYSQL_USER", "root")), + password: Some(shared::env_or("MYSQL_PASSWORD", "123456")), + database: Some(shared::env_or("MYSQL_DB", "test_db")), + ..Default::default() + }; + apply_mysql_env_overrides(&mut form); + form +} + +fn apply_mysql_env_overrides(form: &mut ConnectionForm) { + if let Ok(host) = env::var("MYSQL_HOST") { + form.host = Some(host); + } + if let Ok(port) = env::var("MYSQL_PORT") { + form.port = Some( + port.parse::() + .expect("MYSQL_PORT should be a valid number"), + ); + } + if let Ok(user) = env::var("MYSQL_USER") { + form.username = Some(user); + } + if let Ok(password) = env::var("MYSQL_PASSWORD") { + form.password = Some(password); + } + if let Ok(database) = env::var("MYSQL_DB") { + form.database = Some(database); + } +} diff --git a/src-tauri/tests/common/postgres_context.rs b/src-tauri/tests/common/postgres_context.rs new file mode 100644 index 0000000..5a14090 --- /dev/null +++ b/src-tauri/tests/common/postgres_context.rs @@ -0,0 +1,79 @@ +mod shared; + +use dbpaw_lib::models::ConnectionForm; +use std::time::Duration; +use testcontainers::clients::Cli; +use 
testcontainers::core::WaitFor; +use testcontainers::{Container, GenericImage, RunnableImage}; + +pub use shared::{connect_with_retry, should_reuse_local_db}; + +pub fn postgres_form_from_test_context<'a>( + docker: Option<&'a Cli>, +) -> (Option>, ConnectionForm) { + if should_reuse_local_db() { + return (None, postgres_form_from_local_env()); + } + shared::ensure_docker_available(); + + let docker = docker.expect("docker client is required when IT_REUSE_LOCAL_DB is not enabled"); + let image = GenericImage::new("postgres", "16-alpine") + .with_env_var("POSTGRES_USER", "postgres") + .with_env_var("POSTGRES_PASSWORD", "postgres") + .with_env_var("POSTGRES_DB", "postgres") + .with_wait_for(WaitFor::seconds(3)) + .with_exposed_port(5432); + let runnable = RunnableImage::from(image) + .with_container_name(shared::unique_container_name("postgres")); + let container = docker.run(runnable); + let port = container.get_host_port_ipv4(5432); + + shared::wait_for_port("127.0.0.1", port, Duration::from_secs(45)); + + let mut form = ConnectionForm { + driver: "postgres".to_string(), + host: Some("127.0.0.1".to_string()), + port: Some(i64::from(port)), + username: Some("postgres".to_string()), + password: Some("postgres".to_string()), + database: Some("postgres".to_string()), + ..Default::default() + }; + apply_postgres_env_overrides(&mut form); + (Some(container), form) +} + +fn postgres_form_from_local_env() -> ConnectionForm { + let mut form = ConnectionForm { + driver: "postgres".to_string(), + host: Some(shared::env_or_any(&["POSTGRES_HOST", "PG_HOST"], "localhost")), + port: Some(shared::env_i64_any(&["POSTGRES_PORT", "PG_PORT"], 5432)), + username: Some(shared::env_or_any(&["POSTGRES_USER", "PGUSER"], "postgres")), + password: Some(shared::env_or_any(&["POSTGRES_PASSWORD", "PGPASSWORD"], "postgres")), + database: Some(shared::env_or_any(&["POSTGRES_DB", "PGDATABASE"], "postgres")), + ..Default::default() + }; + apply_postgres_env_overrides(&mut form); + form +} + +fn 
apply_postgres_env_overrides(form: &mut ConnectionForm) { + if let Some(host) = shared::env_any(&["POSTGRES_HOST", "PG_HOST"]) { + form.host = Some(host); + } + if let Some(port) = shared::env_any(&["POSTGRES_PORT", "PG_PORT"]) { + form.port = Some( + port.parse::() + .expect("POSTGRES_PORT/PG_PORT should be a valid number"), + ); + } + if let Some(user) = shared::env_any(&["POSTGRES_USER", "PGUSER"]) { + form.username = Some(user); + } + if let Some(password) = shared::env_any(&["POSTGRES_PASSWORD", "PGPASSWORD"]) { + form.password = Some(password); + } + if let Some(database) = shared::env_any(&["POSTGRES_DB", "PGDATABASE"]) { + form.database = Some(database); + } +} diff --git a/src-tauri/tests/common/shared.rs b/src-tauri/tests/common/shared.rs new file mode 100644 index 0000000..a4300fd --- /dev/null +++ b/src-tauri/tests/common/shared.rs @@ -0,0 +1,112 @@ +use std::env; +use std::future::Future; +use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; +use std::process::Command; +use std::thread::sleep; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +const CONNECT_RETRY_ATTEMPTS: usize = 20; +const CONNECT_RETRY_DELAY_MS: u64 = 500; + +pub fn should_reuse_local_db() -> bool { + env::var("IT_REUSE_LOCAL_DB") + .map(|value| value == "1" || value.eq_ignore_ascii_case("true")) + .unwrap_or(false) +} + +pub fn wait_for_port(host: &str, port: u16, timeout: Duration) { + let started_at = Instant::now(); + let addr = resolve_socket_addr(host, port); + + while started_at.elapsed() < timeout { + if TcpStream::connect_timeout(&addr, Duration::from_millis(500)).is_ok() { + return; + } + sleep(Duration::from_millis(500)); + } + + panic!("timed out waiting for {}:{} to accept connections", host, port); +} + +pub fn ensure_docker_available() { + let output = Command::new("docker").arg("info").output().unwrap_or_else(|error| { + panic!( + "failed to run `docker info`: {}. 
Install/start Docker, or run with IT_REUSE_LOCAL_DB=1", + error + ) + }); + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + panic!( + "`docker info` failed: {}. Start Docker daemon, or run with IT_REUSE_LOCAL_DB=1", + stderr.trim() + ); + } +} + +pub fn unique_container_name(kind: &str) -> String { + let prefix = env::var("IT_CONTAINER_PREFIX").unwrap_or_else(|_| "dbpaw-it-".to_string()); + let pid = std::process::id(); + let ts = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)) + .as_millis(); + format!("{prefix}{kind}-{pid}-{ts}") +} + +#[allow(dead_code)] +pub fn env_or(name: &str, default: &str) -> String { + env::var(name).unwrap_or_else(|_| default.to_string()) +} + +#[allow(dead_code)] +pub fn env_or_any(names: &[&str], default: &str) -> String { + env_any(names).unwrap_or_else(|| default.to_string()) +} + +#[allow(dead_code)] +pub fn env_any(names: &[&str]) -> Option { + names.iter().find_map(|name| env::var(name).ok()) +} + +#[allow(dead_code)] +pub fn env_i64(name: &str, default: i64) -> i64 { + env::var(name) + .ok() + .and_then(|value| value.parse::().ok()) + .unwrap_or(default) +} + +#[allow(dead_code)] +pub fn env_i64_any(names: &[&str], default: i64) -> i64 { + env_any(names) + .and_then(|value| value.parse::().ok()) + .unwrap_or(default) +} + +pub async fn connect_with_retry(mut connect: F) -> T +where + F: FnMut() -> Fut, + Fut: Future>, +{ + let mut last_error = String::new(); + for _ in 0..CONNECT_RETRY_ATTEMPTS { + match connect().await { + Ok(value) => return value, + Err(err) => { + last_error = err; + tokio::time::sleep(Duration::from_millis(CONNECT_RETRY_DELAY_MS)).await; + } + } + } + panic!("Failed to connect after retries: {last_error}"); +} + +fn resolve_socket_addr(host: &str, port: u16) -> SocketAddr { + (host, port) + .to_socket_addrs() + .expect("failed to resolve socket address") + .next() + .expect("resolved zero socket addresses") +} diff --git 
a/src-tauri/tests/duckdb_integration.rs b/src-tauri/tests/duckdb_integration.rs index de3d4e6..4ff7824 100644 --- a/src-tauri/tests/duckdb_integration.rs +++ b/src-tauri/tests/duckdb_integration.rs @@ -1,38 +1,25 @@ +#[path = "common/duckdb_context.rs"] +mod duckdb_context; + use dbpaw_lib::db::drivers::duckdb::DuckdbDriver; use dbpaw_lib::db::drivers::DatabaseDriver; -use dbpaw_lib::models::ConnectionForm; -use std::env; -use std::path::PathBuf; -use uuid::Uuid; - -fn duckdb_test_path() -> PathBuf { - if let Ok(v) = env::var("DUCKDB_IT_DB_PATH") { - return PathBuf::from(v); - } - let mut p = env::temp_dir(); - p.push(format!( - "dbpaw-duckdb-integration-{}.duckdb", - Uuid::new_v4() - )); - p -} #[tokio::test] #[ignore] async fn test_duckdb_integration_flow() { - let db_path = duckdb_test_path(); - let db_path_str = db_path.to_string_lossy().to_string(); - - let form = ConnectionForm { - driver: "duckdb".to_string(), - file_path: Some(db_path_str.clone()), - ..Default::default() - }; + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); let driver = DuckdbDriver::connect(&form) .await .expect("Failed to connect to duckdb"); + let _ = driver + .execute_query("DROP VIEW IF EXISTS dbpaw_duckdb_type_probe_v".to_string()) + .await; + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_type_probe".to_string()) + .await; + driver .test_connection() .await @@ -42,14 +29,15 @@ async fn test_duckdb_integration_flow() { .list_databases() .await .expect("list_databases failed"); - assert!(dbs.iter().any(|db| db == "main")); + assert!(!dbs.is_empty(), "list_databases returned empty"); driver .execute_query( - "CREATE TABLE IF NOT EXISTS duck_type_probe (\ + "CREATE TABLE IF NOT EXISTS dbpaw_duckdb_type_probe (\ id INTEGER PRIMARY KEY, \ name VARCHAR, \ - amount DOUBLE\ + amount DOUBLE, \ + created_at TIMESTAMP\ )" .to_string(), ) @@ -58,8 +46,17 @@ async fn test_duckdb_integration_flow() { driver .execute_query( - "INSERT INTO duck_type_probe 
(id, name, amount) \ - VALUES (1, 'hello', 12.34)" + "CREATE VIEW IF NOT EXISTS dbpaw_duckdb_type_probe_v AS \ + SELECT id, name FROM dbpaw_duckdb_type_probe" + .to_string(), + ) + .await + .expect("create view failed"); + + driver + .execute_query( + "INSERT INTO dbpaw_duckdb_type_probe (id, name, amount, created_at) \ + VALUES (1, 'hello', 12.34, '2026-01-02 03:04:05')" .to_string(), ) .await @@ -67,12 +64,16 @@ async fn test_duckdb_integration_flow() { let tables = driver.list_tables(None).await.expect("list_tables failed"); assert!( - tables.iter().any(|t| t.name == "duck_type_probe"), - "list_tables should include duck_type_probe" + tables.iter().any(|t| t.name == "dbpaw_duckdb_type_probe"), + "list_tables should include dbpaw_duckdb_type_probe" + ); + assert!( + tables.iter().any(|t| t.name == "dbpaw_duckdb_type_probe_v"), + "list_tables should include dbpaw_duckdb_type_probe_v" ); let metadata = driver - .get_table_metadata("main".to_string(), "duck_type_probe".to_string()) + .get_table_metadata("main".to_string(), "dbpaw_duckdb_type_probe".to_string()) .await .expect("get_table_metadata failed"); assert!( @@ -81,7 +82,7 @@ async fn test_duckdb_integration_flow() { ); let ddl = driver - .get_table_ddl("main".to_string(), "duck_type_probe".to_string()) + .get_table_ddl("main".to_string(), "dbpaw_duckdb_type_probe".to_string()) .await .expect("get_table_ddl failed"); assert!( @@ -90,7 +91,11 @@ async fn test_duckdb_integration_flow() { ); let result = driver - .execute_query("SELECT id, name, amount FROM duck_type_probe WHERE id = 1".to_string()) + .execute_query( + "SELECT id, name, amount, created_at \ + FROM dbpaw_duckdb_type_probe WHERE id = 1" + .to_string(), + ) .await .expect("select typed row failed"); assert_eq!(result.row_count, 1); @@ -98,14 +103,827 @@ async fn test_duckdb_integration_flow() { .data .first() .expect("typed result should include at least one row"); - assert_eq!(row["id"], serde_json::Value::String("1".to_string())); + let id_value = 
row.get("id").expect("id should exist"); + assert!( + id_value == &serde_json::Value::String("1".to_string()) + || id_value == &serde_json::Value::Number(serde_json::Number::from(1)), + "unexpected id value: {:?}", + id_value + ); assert_eq!(row["name"], serde_json::Value::String("hello".to_string())); assert!(row.get("amount").is_some(), "amount should exist"); + assert!(row.get("created_at").is_some(), "created_at should exist"); + + let _ = driver + .execute_query("DROP VIEW IF EXISTS dbpaw_duckdb_type_probe_v".to_string()) + .await; + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_type_probe".to_string()) + .await; + driver.close().await; + + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_get_table_data_supports_pagination_sort_filter_and_order_by() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + let driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect to duckdb"); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_grid_probe".to_string()) + .await; + + driver + .execute_query( + "CREATE TABLE dbpaw_duckdb_grid_probe (id INTEGER PRIMARY KEY, name VARCHAR, score INTEGER)" + .to_string(), + ) + .await + .expect("create dbpaw_duckdb_grid_probe failed"); + driver + .execute_query( + "INSERT INTO dbpaw_duckdb_grid_probe (id, name, score) VALUES \ + (1, 'alpha', 10), (2, 'beta', 20), (3, 'gamma', 30), (4, 'delta', 40)" + .to_string(), + ) + .await + .expect("insert dbpaw_duckdb_grid_probe failed"); + + let page1 = driver + .get_table_data( + "main".to_string(), + "dbpaw_duckdb_grid_probe".to_string(), + 1, + 2, + Some("score".to_string()), + Some("desc".to_string()), + None, + None, + ) + .await + .expect("get_table_data page1 failed"); + assert_eq!(page1.total, 4); + assert_eq!(page1.data.len(), 2); + assert_eq!( + page1.data[0]["name"], + serde_json::Value::String("delta".to_string()) + ); + + let filtered = driver + 
.get_table_data( + "main".to_string(), + "dbpaw_duckdb_grid_probe".to_string(), + 1, + 10, + None, + None, + Some("score >= 20".to_string()), + None, + ) + .await + .expect("get_table_data with filter failed"); + assert_eq!(filtered.total, 3); + + let ordered = driver + .get_table_data( + "main".to_string(), + "dbpaw_duckdb_grid_probe".to_string(), + 1, + 1, + Some("id".to_string()), + Some("asc".to_string()), + None, + Some("name DESC".to_string()), + ) + .await + .expect("get_table_data with order_by failed"); + assert_eq!(ordered.total, 4); + assert_eq!(ordered.data.len(), 1); + assert_eq!( + ordered.data[0]["name"], + serde_json::Value::String("gamma".to_string()) + ); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_grid_probe".to_string()) + .await; + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_get_table_data_rejects_invalid_sort_column() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + let driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect to duckdb"); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_invalid_sort_probe".to_string()) + .await; + + driver + .execute_query( + "CREATE TABLE dbpaw_duckdb_invalid_sort_probe (id INTEGER PRIMARY KEY)".to_string(), + ) + .await + .expect("create dbpaw_duckdb_invalid_sort_probe failed"); + + let result = driver + .get_table_data( + "main".to_string(), + "dbpaw_duckdb_invalid_sort_probe".to_string(), + 1, + 10, + Some("id desc".to_string()), + Some("desc".to_string()), + None, + None, + ) + .await; + let err = result.expect_err("invalid sort column should return error"); + assert!( + err.contains("[VALIDATION_ERROR] Invalid sort column name"), + "unexpected error: {}", + err + ); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_invalid_sort_probe".to_string()) + .await; + driver.close().await; + let _ = 
std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_table_structure_and_schema_overview() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + let driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect to duckdb"); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_overview_probe".to_string()) + .await; + + driver + .execute_query( + "CREATE TABLE dbpaw_duckdb_overview_probe (id INTEGER PRIMARY KEY, label VARCHAR NOT NULL)" + .to_string(), + ) + .await + .expect("create dbpaw_duckdb_overview_probe failed"); + + let structure = driver + .get_table_structure( + "main".to_string(), + "dbpaw_duckdb_overview_probe".to_string(), + ) + .await + .expect("get_table_structure failed"); + assert!( + structure + .columns + .iter() + .any(|c| c.name == "id" && c.primary_key), + "table structure should include primary key id" + ); + assert!( + structure.columns.iter().any(|c| c.name == "label"), + "table structure should include label" + ); + + let overview = driver + .get_schema_overview(Some("main".to_string())) + .await + .expect("get_schema_overview failed"); + assert!( + overview + .tables + .iter() + .any(|t| t.schema == "main" && t.name == "dbpaw_duckdb_overview_probe"), + "schema overview should include main.dbpaw_duckdb_overview_probe" + ); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_overview_probe".to_string()) + .await; + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_metadata_includes_empty_indexes_and_foreign_keys() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + let driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect to duckdb"); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_meta_probe".to_string()) + .await; + + driver + .execute_query( + "CREATE TABLE dbpaw_duckdb_meta_probe (id INTEGER PRIMARY 
KEY, name VARCHAR)" + .to_string(), + ) + .await + .expect("create dbpaw_duckdb_meta_probe failed"); + + let metadata = driver + .get_table_metadata("main".to_string(), "dbpaw_duckdb_meta_probe".to_string()) + .await + .expect("get_table_metadata failed"); + assert!( + metadata.columns.iter().any(|c| c.name == "id"), + "metadata should include id column" + ); + assert!( + metadata.indexes.is_empty(), + "duckdb metadata indexes should be empty for now" + ); + assert!( + metadata.foreign_keys.is_empty(), + "duckdb metadata foreign_keys should be empty for now" + ); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_meta_probe".to_string()) + .await; + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_boolean_and_json_type_mapping_regression() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + let driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect to duckdb"); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_bool_json_probe".to_string()) + .await; + + driver + .execute_query( + "CREATE TABLE dbpaw_duckdb_bool_json_probe (id INTEGER PRIMARY KEY, flag BOOLEAN, meta VARCHAR)" + .to_string(), + ) + .await + .expect("create dbpaw_duckdb_bool_json_probe failed"); + driver + .execute_query( + "INSERT INTO dbpaw_duckdb_bool_json_probe (id, flag, meta) VALUES \ + (1, true, '{\"tier\":\"gold\"}')" + .to_string(), + ) + .await + .expect("insert dbpaw_duckdb_bool_json_probe failed"); + + let query_result = driver + .execute_query( + "SELECT flag, meta FROM dbpaw_duckdb_bool_json_probe WHERE id = 1".to_string(), + ) + .await + .expect("select bool/json probe row failed"); + assert_eq!(query_result.row_count, 1); + let query_row = query_result.data.first().expect("query row should exist"); + let query_flag = query_row + .get("flag") + .expect("flag should exist in query result"); + assert!( + query_flag == 
&serde_json::Value::Bool(true) + || query_flag == &serde_json::Value::Number(serde_json::Number::from(1)) + || query_flag == &serde_json::Value::String("true".to_string()), + "unexpected query flag value: {:?}", + query_flag + ); + assert!(query_row.get("meta").is_some(), "meta should exist"); + + let table_data = driver + .get_table_data( + "main".to_string(), + "dbpaw_duckdb_bool_json_probe".to_string(), + 1, + 10, + None, + None, + None, + None, + ) + .await + .expect("get_table_data for dbpaw_duckdb_bool_json_probe failed"); + assert_eq!(table_data.total, 1); + let grid_row = table_data.data.first().expect("table row should exist"); + let grid_flag = grid_row + .get("flag") + .expect("flag should exist in table_data result"); + assert!( + grid_flag == &serde_json::Value::Bool(true) + || grid_flag == &serde_json::Value::Number(serde_json::Number::from(1)) + || grid_flag == &serde_json::Value::String("true".to_string()), + "unexpected grid flag value: {:?}", + grid_flag + ); + assert!( + grid_row.get("meta").is_some(), + "meta should exist in table_data" + ); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_bool_json_probe".to_string()) + .await; + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_execute_query_reports_affected_rows_for_update_delete() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + let driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect to duckdb"); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_affected_rows_probe".to_string()) + .await; + driver + .execute_query( + "CREATE TABLE dbpaw_duckdb_affected_rows_probe (id INTEGER PRIMARY KEY, name VARCHAR)" + .to_string(), + ) + .await + .expect("create affected_rows probe table failed"); + + let inserted = driver + .execute_query( + "INSERT INTO dbpaw_duckdb_affected_rows_probe (id, name) VALUES (1, 'a'), (2, 'b')" + .to_string(), + ) + 
.await + .expect("insert affected_rows probe rows failed"); + assert_eq!(inserted.row_count, 2); + + let updated = driver + .execute_query( + "UPDATE dbpaw_duckdb_affected_rows_probe SET name = 'bb' WHERE id = 2".to_string(), + ) + .await + .expect("update affected_rows probe row failed"); + assert_eq!(updated.row_count, 1); + + let deleted = driver + .execute_query( + "DELETE FROM dbpaw_duckdb_affected_rows_probe WHERE id IN (1, 2)".to_string(), + ) + .await + .expect("delete affected_rows probe rows failed"); + assert_eq!(deleted.row_count, 2); let _ = driver - .execute_query("DROP TABLE IF EXISTS duck_type_probe".to_string()) + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_affected_rows_probe".to_string()) .await; driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_transaction_commit_and_rollback() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + let driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect to duckdb"); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_txn_probe".to_string()) + .await; + driver + .execute_query( + "CREATE TABLE dbpaw_duckdb_txn_probe (id INTEGER PRIMARY KEY, name VARCHAR)" + .to_string(), + ) + .await + .expect("create duckdb txn probe table failed"); + + driver + .execute_query( + "BEGIN TRANSACTION; \ + INSERT INTO dbpaw_duckdb_txn_probe (id, name) VALUES (1, 'rolled_back'); \ + ROLLBACK;" + .to_string(), + ) + .await + .expect("rollback flow failed"); + + let rolled_back = driver + .execute_query("SELECT COUNT(*) AS c FROM dbpaw_duckdb_txn_probe WHERE id = 1".to_string()) + .await + .expect("count after rollback failed"); + let rolled_back_count = rolled_back.data[0]["c"] + .as_str() + .expect("rollback count should be string") + .parse::() + .expect("rollback count should be numeric"); + assert_eq!(rolled_back_count, 0); + + driver + .execute_query( + "BEGIN TRANSACTION; \ + INSERT INTO 
dbpaw_duckdb_txn_probe (id, name) VALUES (2, 'committed'); \ + COMMIT;" + .to_string(), + ) + .await + .expect("commit flow failed"); + + let committed = driver + .execute_query("SELECT COUNT(*) AS c FROM dbpaw_duckdb_txn_probe WHERE id = 2".to_string()) + .await + .expect("count after commit failed"); + let committed_count = committed.data[0]["c"] + .as_str() + .expect("commit count should be string") + .parse::() + .expect("commit count should be numeric"); + assert_eq!(committed_count, 1); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_txn_probe".to_string()) + .await; + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_error_handling_for_sql_error() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + let driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect to duckdb"); + let err = driver + .execute_query("SELECT * FROM __dbpaw_table_not_exists".to_string()) + .await + .expect_err("invalid SQL should return query error"); + assert!( + err.contains("[QUERY_ERROR]"), + "unexpected error shape: {}", + err + ); + + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_connection_failure_with_invalid_path() { + let form = dbpaw_lib::models::ConnectionForm { + driver: "duckdb".to_string(), + file_path: Some("/nonexistent/path/that/cannot/be/created/dbpaw_test.duckdb".to_string()), + ..Default::default() + }; + + let err = match DuckdbDriver::connect(&form).await { + Ok(_) => panic!("invalid path should fail"), + Err(err) => err, + }; + assert!( + err.starts_with("[CONN_FAILED]"), + "unexpected error: {}", + err + ); + assert!(!err.trim().is_empty(), "error message should not be empty"); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_database_locked_error() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + + let driver1 = 
DuckdbDriver::connect(&form) + .await + .expect("First connection should succeed"); + + let driver2_result = DuckdbDriver::connect(&form).await; + + match driver2_result { + Ok(driver2) => { + driver2.close().await; + } + Err(err) => { + assert!( + err.starts_with("[CONN_FAILED]") || err.contains("busy") || err.contains("locked"), + "expected lock/busy error, got: {}", + err + ); + } + } + + driver1.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_batch_insert_and_batch_execute_flow() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + let driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect to duckdb"); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_batch_probe".to_string()) + .await; + driver + .execute_query( + "CREATE TABLE dbpaw_duckdb_batch_probe (id INTEGER PRIMARY KEY, category VARCHAR, score INTEGER)" + .to_string(), + ) + .await + .expect("create batch probe table failed"); + + let value_rows: Vec = (1..=50) + .map(|id| { + let category = if id <= 25 { "alpha" } else { "beta" }; + format!("({}, '{}', {})", id, category, id) + }) + .collect(); + let insert_sql = format!( + "INSERT INTO dbpaw_duckdb_batch_probe (id, category, score) VALUES {}", + value_rows.join(", ") + ); + let inserted = driver + .execute_query(insert_sql) + .await + .expect("batch insert failed"); + assert_eq!(inserted.row_count, 50); + + let batch_sqls = vec![ + "UPDATE dbpaw_duckdb_batch_probe SET score = score + 100 WHERE id <= 10".to_string(), + "UPDATE dbpaw_duckdb_batch_probe SET category = 'gamma' WHERE id BETWEEN 30 AND 40" + .to_string(), + "DELETE FROM dbpaw_duckdb_batch_probe WHERE id IN (3, 6, 9, 12, 15)".to_string(), + ]; + let mut affected = Vec::new(); + for sql in batch_sqls { + let result = driver + .execute_query(sql) + .await + .expect("batch execute statement failed"); + affected.push(result.row_count); + } + assert_eq!(affected, 
vec![10, 11, 5]); + + let check_total = driver + .execute_query("SELECT COUNT(*) AS c FROM dbpaw_duckdb_batch_probe".to_string()) + .await + .expect("count after batch execute failed"); + let total = check_total.data[0]["c"] + .as_str() + .expect("count should be string") + .parse::() + .expect("count should be numeric"); + assert_eq!(total, 45); + + let check_gamma = driver + .execute_query( + "SELECT COUNT(*) AS c FROM dbpaw_duckdb_batch_probe WHERE category = 'gamma'" + .to_string(), + ) + .await + .expect("count gamma rows failed"); + let gamma = check_gamma.data[0]["c"] + .as_str() + .expect("gamma count should be string") + .parse::() + .expect("gamma count should be numeric"); + assert_eq!(gamma, 11); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_batch_probe".to_string()) + .await; + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_large_text_and_blob_round_trip() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + let driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect to duckdb"); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_large_field_probe".to_string()) + .await; + driver + .execute_query( + "CREATE TABLE dbpaw_duckdb_large_field_probe (id INTEGER PRIMARY KEY, body TEXT, payload BLOB)" + .to_string(), + ) + .await + .expect("create large field probe table failed"); + + let large_text = "x".repeat(70000); + let blob_data: Vec = (0..4096).map(|i| (i % 256) as u8).collect(); + let blob_hex: String = blob_data.iter().map(|b| format!("{:02x}", b)).collect(); + + driver + .execute_query( + format!( + "INSERT INTO dbpaw_duckdb_large_field_probe (id, body, payload) VALUES (1, '{}', '{}'::BLOB)", + large_text, blob_hex + ) + .to_string(), + ) + .await + .expect("insert large field probe row failed"); + + let result = driver + .execute_query( + "SELECT body, payload FROM 
dbpaw_duckdb_large_field_probe WHERE id = 1".to_string(), + ) + .await + .expect("select large field probe row failed"); + assert_eq!(result.row_count, 1); + let row = result.data.first().expect("large field row should exist"); + let body = row + .get("body") + .and_then(|v| v.as_str()) + .expect("body should be string"); + assert_eq!(body.len(), 70000); + assert!(row.get("payload").is_some(), "payload should exist"); + + let _ = driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_large_field_probe".to_string()) + .await; + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_concurrent_connections_can_query() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + + let driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect to duckdb"); + driver + .execute_query( + "CREATE TABLE IF NOT EXISTS dbpaw_duckdb_concurrent_probe (id INTEGER, value VARCHAR)" + .to_string(), + ) + .await + .expect("create concurrent probe table failed"); + driver + .execute_query("INSERT INTO dbpaw_duckdb_concurrent_probe VALUES (1, 'test')".to_string()) + .await + .expect("insert concurrent probe row failed"); + driver.close().await; + + let mut handles = Vec::new(); + + for i in 0..4 { + let task_db_path = db_path.clone(); + handles.push(tokio::spawn(async move { + let task_form = dbpaw_lib::models::ConnectionForm { + driver: "duckdb".to_string(), + file_path: Some(task_db_path.to_string_lossy().to_string()), + ..Default::default() + }; + let task_driver = DuckdbDriver::connect(&task_form) + .await + .expect("Failed to connect to duckdb in concurrent task"); + tokio::time::sleep(tokio::time::Duration::from_millis(10 * i as u64)).await; + let result = task_driver + .execute_query( + "SELECT * FROM dbpaw_duckdb_concurrent_probe WHERE id = 1".to_string(), + ) + .await; + task_driver.close().await; + result + })); + } + + for handle in handles { + let result = 
handle.await.expect("concurrent duckdb task panicked"); + let data = result.expect("concurrent duckdb query failed"); + assert_eq!(data.row_count, 1); + assert_eq!( + data.data[0]["value"], + serde_json::Value::String("test".to_string()) + ); + } + + let cleanup_driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect for cleanup"); + let _ = cleanup_driver + .execute_query("DROP TABLE IF EXISTS dbpaw_duckdb_concurrent_probe".to_string()) + .await; + cleanup_driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_duckdb_view_can_be_listed_and_queried() { + let (db_path, form) = duckdb_context::duckdb_form_from_test_context(); + let driver = DuckdbDriver::connect(&form) + .await + .expect("Failed to connect to duckdb"); + + let base_table = "dbpaw_duckdb_view_base_probe"; + let view_name = "dbpaw_duckdb_view_probe_v"; + + let _ = driver + .execute_query(format!("DROP VIEW IF EXISTS {}", view_name)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", base_table)) + .await; + + driver + .execute_query( + format!( + "CREATE TABLE {} (id INTEGER PRIMARY KEY, name VARCHAR, score INTEGER)", + base_table + ) + .to_string(), + ) + .await + .expect("create base table for view failed"); + driver + .execute_query( + format!( + "INSERT INTO {} (id, name, score) VALUES (1, 'alice', 10), (2, 'bob', 20)", + base_table + ) + .to_string(), + ) + .await + .expect("insert base rows for view failed"); + driver + .execute_query( + format!( + "CREATE VIEW {} AS SELECT id, name FROM {} WHERE score >= 20", + view_name, base_table + ) + .to_string(), + ) + .await + .expect("create view failed"); + + let tables = driver + .list_tables(Some("main".to_string())) + .await + .expect("list_tables failed"); + assert!( + tables + .iter() + .any(|t| t.name == base_table && t.r#type == "table"), + "list_tables should include base table" + ); + assert!( + tables + .iter() + .any(|t| t.name == view_name 
&& t.r#type == "view"), + "list_tables should include view with type=view" + ); + + let view_rows = driver + .execute_query(format!("SELECT id, name FROM {} ORDER BY id", view_name).to_string()) + .await + .expect("select from view failed"); + assert_eq!(view_rows.row_count, 1); + let row = view_rows.data.first().expect("view row should exist"); + let id_matches = row["id"] == serde_json::Value::Number(2.into()) + || row["id"] == serde_json::Value::String("2".to_string()); + assert!(id_matches, "unexpected id payload: {}", row["id"]); + assert_eq!(row["name"], serde_json::Value::String("bob".to_string())); + + let _ = driver + .execute_query(format!("DROP VIEW IF EXISTS {}", view_name)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", base_table)) + .await; + driver.close().await; let _ = std::fs::remove_file(db_path); } diff --git a/src-tauri/tests/mariadb_integration.rs b/src-tauri/tests/mariadb_integration.rs index 7386a01..6061f20 100644 --- a/src-tauri/tests/mariadb_integration.rs +++ b/src-tauri/tests/mariadb_integration.rs @@ -1,127 +1,114 @@ +#[path = "common/mariadb_context.rs"] +mod mariadb_context; + use dbpaw_lib::db::drivers::mysql::MysqlDriver; use dbpaw_lib::db::drivers::DatabaseDriver; -use dbpaw_lib::models::ConnectionForm; -use std::env; +use testcontainers::clients::Cli; #[tokio::test] #[ignore] async fn test_mariadb_integration_flow() { - let host = env::var("MARIADB_HOST").unwrap_or_else(|_| "localhost".to_string()); - let port = env::var("MARIADB_PORT") - .unwrap_or_else(|_| "3306".to_string()) - .parse() - .unwrap(); - let username = env::var("MARIADB_USER").unwrap_or_else(|_| "root".to_string()); - let password = env::var("MARIADB_PASSWORD").unwrap_or_else(|_| "123456".to_string()); - let database = env::var("MARIADB_DB").ok(); - - let form = ConnectionForm { - driver: "mariadb".to_string(), - host: Some(host), - port: Some(port), - username: Some(username), - password: Some(password), - database: 
database.clone(), - ..Default::default() - }; - - let driver = MysqlDriver::connect(&form) - .await - .expect("Failed to connect"); - - driver.test_connection().await.expect("Connection failed"); + let docker = (!mariadb_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mariadb_context::mariadb_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MARIADB_DB or container default database should be present"); + let driver: MysqlDriver = + mariadb_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + driver + .test_connection() + .await + .expect("test_connection failed"); let dbs = driver .list_databases() .await - .expect("Failed to list databases"); - assert!(!dbs.is_empty()); - - if let Some(db_name) = database { - let table = "dbpaw_mariadb_integration"; - let qualified = format!("`{}`.`{}`", db_name, table); - - driver - .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) - .await - .expect("drop before test failed"); - - driver - .execute_query(format!( - "CREATE TABLE {} (id INT PRIMARY KEY, name VARCHAR(50))", - qualified - )) - .await - .expect("create table failed"); - - driver - .execute_query(format!( - "INSERT INTO {} (id, name) VALUES (1, 'MariaDB')", - qualified - )) - .await - .expect("insert failed"); - - let result = driver - .execute_query(format!("SELECT name FROM {} WHERE id = 1", qualified)) - .await - .expect("select failed"); - assert_eq!(result.row_count, 1); - assert_eq!( - result.data[0].get("name").and_then(|v| v.as_str()), - Some("MariaDB") - ); - - driver - .execute_query(format!("DROP TABLE {}", qualified)) - .await - .expect("drop after test failed"); - } -} + .expect("list_databases failed"); + assert!(!dbs.is_empty(), "list_databases returned empty"); + assert!( + dbs.iter().any(|db| db == &database), + "list_databases should include {}", + database + ); -#[tokio::test] -#[ignore] -async fn 
test_mariadb_show_create_and_information_schema_compat() { - let host = env::var("MARIADB_HOST").unwrap_or_else(|_| "localhost".to_string()); - let port = env::var("MARIADB_PORT") - .unwrap_or_else(|_| "3306".to_string()) - .parse() - .unwrap(); - let username = env::var("MARIADB_USER").unwrap_or_else(|_| "root".to_string()); - let password = env::var("MARIADB_PASSWORD").unwrap_or_else(|_| "123456".to_string()); - let database = env::var("MARIADB_DB").unwrap_or_else(|_| "test".to_string()); - - let form = ConnectionForm { - driver: "mariadb".to_string(), - host: Some(host), - port: Some(port), - username: Some(username), - password: Some(password), - database: Some(database.clone()), - ..Default::default() - }; - - let driver = MysqlDriver::connect(&form) - .await - .expect("Failed to connect"); - let table = "dbpaw_mariadb_meta"; - let qualified = format!("`{}`.`{}`", database, table); + let table_name = "dbpaw_mariadb_type_probe"; + let view_name = "dbpaw_mariadb_type_probe_v"; + let qualified = format!("`{}`.`{}`", database, table_name); + let qualified_view = format!("`{}`.`{}`", database, view_name); - driver + let _ = driver + .execute_query(format!("DROP VIEW IF EXISTS {}", qualified_view)) + .await; + let _ = driver .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) - .await - .expect("drop before test failed"); + .await; driver .execute_query(format!( - "CREATE TABLE {} (id INT PRIMARY KEY, payload VARCHAR(50))", + "CREATE TABLE {} (\ + id INT PRIMARY KEY, \ + name VARCHAR(50), \ + amount DECIMAL(10,2), \ + payload VARBINARY(16), \ + created_at DATETIME\ + )", qualified )) .await .expect("create table failed"); + driver + .execute_query(format!( + "CREATE VIEW {} AS SELECT id, name FROM {}", + qualified_view, qualified + )) + .await + .expect("create view failed"); + + driver + .execute_query(format!( + "INSERT INTO {} (id, name, amount, payload, created_at) \ + VALUES (1, 'hello', 12.34, UNHEX('DEADBEEF'), '2026-01-02 03:04:05')", + qualified + 
)) + .await + .expect("insert failed"); + + let tables = driver + .list_tables(Some(database.clone())) + .await + .expect("list_tables failed"); + assert!( + tables.iter().any(|t| t.name == table_name), + "list_tables should include {}", + table_name + ); + assert!( + tables.iter().any(|t| t.name == view_name), + "list_tables should include {}", + view_name + ); + + let metadata = driver + .get_table_metadata(database.clone(), table_name.to_string()) + .await + .expect("get_table_metadata failed"); + assert!( + metadata + .columns + .iter() + .any(|c| c.name == "id" && c.primary_key), + "metadata should include primary key id" + ); + assert!( + metadata.columns.iter().any(|c| c.name == "payload"), + "metadata should include payload column" + ); + let ddl = driver - .get_table_ddl(database.clone(), table.to_string()) + .get_table_ddl(database.clone(), table_name.to_string()) .await .expect("get_table_ddl failed"); assert!( @@ -129,18 +116,399 @@ async fn test_mariadb_show_create_and_information_schema_compat() { "DDL should contain CREATE TABLE" ); + let result = driver + .execute_query(format!( + "SELECT id, name, amount, created_at FROM {} WHERE id = 1", + qualified + )) + .await + .expect("select typed row failed"); + assert_eq!(result.row_count, 1); + let row = result + .data + .first() + .expect("typed result should include at least one row"); + let id_value = row.get("id").expect("id should exist"); + assert!( + id_value == &serde_json::Value::String("1".to_string()) + || id_value == &serde_json::Value::Number(serde_json::Number::from(1)), + "unexpected id value: {:?}", + id_value + ); + assert_eq!(row["name"], serde_json::Value::String("hello".to_string())); + assert!(row.get("amount").is_some(), "amount should exist"); + assert!(row.get("created_at").is_some(), "created_at should exist"); + + let _ = driver + .execute_query(format!("DROP VIEW IF EXISTS {}", qualified_view)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", 
qualified)) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mariadb_get_table_data_supports_pagination_sort_filter_and_order_by() { + let docker = (!mariadb_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mariadb_context::mariadb_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MARIADB_DB or container default database should be present"); + let driver: MysqlDriver = + mariadb_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_mariadb_grid_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name VARCHAR(30), score INT)", + qualified + )) + .await + .expect("create dbpaw_mariadb_grid_probe failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name, score) VALUES \ + (1, 'alpha', 10), (2, 'beta', 20), (3, 'gamma', 30), (4, 'delta', 40)", + qualified + )) + .await + .expect("insert dbpaw_mariadb_grid_probe failed"); + + let page1 = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 2, + Some("score".to_string()), + Some("desc".to_string()), + None, + None, + ) + .await + .expect("get_table_data page1 failed"); + assert_eq!(page1.total, 4); + assert_eq!(page1.data.len(), 2); + assert_eq!( + page1.data[0]["name"], + serde_json::Value::String("delta".to_string()) + ); + + let filtered = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 10, + None, + None, + Some("score >= 20".to_string()), + None, + ) + .await + .expect("get_table_data with filter failed"); + assert_eq!(filtered.total, 3); + + let ordered = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 1, + Some("id".to_string()), + Some("asc".to_string()), + None, + Some("name 
DESC".to_string()), + ) + .await + .expect("get_table_data with order_by failed"); + assert_eq!(ordered.total, 4); + assert_eq!(ordered.data.len(), 1); + assert_eq!( + ordered.data[0]["name"], + serde_json::Value::String("gamma".to_string()) + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mariadb_get_table_data_rejects_invalid_sort_column() { + let docker = (!mariadb_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mariadb_context::mariadb_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MARIADB_DB or container default database should be present"); + let driver: MysqlDriver = + mariadb_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_mariadb_invalid_sort_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!("CREATE TABLE {} (id INT PRIMARY KEY)", qualified)) + .await + .expect("create dbpaw_mariadb_invalid_sort_probe failed"); + + let result = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 10, + Some("id desc".to_string()), + Some("desc".to_string()), + None, + None, + ) + .await; + let err = result.expect_err("invalid sort column should return error"); + assert!( + err.contains("[VALIDATION_ERROR] Invalid sort column name"), + "unexpected error: {}", + err + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mariadb_table_structure_and_schema_overview() { + let docker = (!mariadb_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mariadb_context::mariadb_form_from_test_context(docker.as_ref()); + let database 
= form + .database + .clone() + .expect("MARIADB_DB or container default database should be present"); + let driver: MysqlDriver = + mariadb_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_mariadb_overview_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, label VARCHAR(50) NOT NULL)", + qualified + )) + .await + .expect("create dbpaw_mariadb_overview_probe failed"); + + let structure = driver + .get_table_structure(database.clone(), table_name.to_string()) + .await + .expect("get_table_structure failed"); + assert!( + structure.columns.iter().any(|c| c.name == "id"), + "table structure should include id" + ); + assert!( + structure.columns.iter().any(|c| c.name == "label"), + "table structure should include label" + ); + let overview = driver .get_schema_overview(Some(database.clone())) .await .expect("get_schema_overview failed"); assert!( - overview.tables.iter().any(|t| t.name == table), - "schema overview should include {}", - table + overview + .tables + .iter() + .any(|t| t.schema == database && t.name == table_name), + "schema overview should include {}.{}", + database, + table_name ); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mariadb_metadata_includes_indexes_and_foreign_keys() { + let docker = (!mariadb_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mariadb_context::mariadb_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MARIADB_DB or container default database should be present"); + let driver: MysqlDriver = + mariadb_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let parent = 
"dbpaw_mariadb_parent_meta_probe"; + let child = "dbpaw_mariadb_child_meta_probe"; + let parent_qualified = format!("`{}`.`{}`", database, parent); + let child_qualified = format!("`{}`.`{}`", database, child); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", child_qualified)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", parent_qualified)) + .await; + driver - .execute_query(format!("DROP TABLE {}", qualified)) + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY)", + parent_qualified + )) .await - .expect("drop after test failed"); + .expect("create parent table failed"); + driver + .execute_query(format!( + "CREATE TABLE {} (\ + id INT PRIMARY KEY, \ + parent_id INT NOT NULL, \ + name VARCHAR(30), \ + INDEX idx_mariadb_child_name (name), \ + CONSTRAINT fk_mariadb_child_parent FOREIGN KEY (parent_id) REFERENCES {}(id)\ + )", + child_qualified, parent_qualified + )) + .await + .expect("create child table with fk failed"); + + let metadata = driver + .get_table_metadata(database.clone(), child.to_string()) + .await + .expect("get_table_metadata failed"); + assert!( + metadata + .indexes + .iter() + .any(|i| i.name == "idx_mariadb_child_name" && i.columns.contains(&"name".to_string())), + "metadata should include idx_mariadb_child_name" + ); + assert!( + metadata + .foreign_keys + .iter() + .any(|fk| fk.column == "parent_id" && fk.referenced_table == parent), + "metadata should include FK parent_id -> {}(id)", + parent + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", child_qualified)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", parent_qualified)) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mariadb_boolean_and_json_type_mapping_regression() { + let docker = (!mariadb_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = 
mariadb_context::mariadb_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MARIADB_DB or container default database should be present"); + let driver: MysqlDriver = + mariadb_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_mariadb_bool_json_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, flag BOOLEAN, meta JSON)", + qualified + )) + .await + .expect("create bool/json probe table failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, flag, meta) VALUES (1, 1, '{{\"tier\": \"gold\"}}')", + qualified + )) + .await + .expect("insert bool/json probe row failed"); + + let query_result = driver + .execute_query(format!( + "SELECT flag, JSON_UNQUOTE(JSON_EXTRACT(meta, '$.tier')) AS tier FROM {} WHERE id = 1", + qualified + )) + .await + .expect("select bool/json row failed"); + assert_eq!(query_result.row_count, 1); + let query_row = query_result.data.first().expect("query row should exist"); + let query_flag = query_row + .get("flag") + .expect("flag should exist in query result"); + assert!( + query_flag == &serde_json::Value::Bool(true) + || query_flag == &serde_json::Value::Number(serde_json::Number::from(1)), + "unexpected query flag value: {:?}", + query_flag + ); + assert_eq!( + query_row["tier"], + serde_json::Value::String("gold".to_string()) + ); + + let table_data = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 10, + None, + None, + None, + None, + ) + .await + .expect("get_table_data for bool/json probe failed"); + assert_eq!(table_data.total, 1); + let grid_row = table_data.data.first().expect("table row should exist"); + let grid_flag = grid_row + .get("flag") + .expect("flag should exist in table_data result"); + assert!( + grid_flag == 
&serde_json::Value::Bool(true) + || grid_flag == &serde_json::Value::Number(serde_json::Number::from(1)), + "unexpected grid flag value: {:?}", + grid_flag + ); + assert!( + grid_row.get("meta").is_some(), + "meta should exist in table_data" + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver.close().await; } diff --git a/src-tauri/tests/mssql_integration.rs b/src-tauri/tests/mssql_integration.rs index c23ca22..da68345 100644 --- a/src-tauri/tests/mssql_integration.rs +++ b/src-tauri/tests/mssql_integration.rs @@ -1,33 +1,31 @@ +#[path = "common/mssql_context.rs"] +mod mssql_context; + use dbpaw_lib::db::drivers::mssql::MssqlDriver; use dbpaw_lib::db::drivers::DatabaseDriver; -use dbpaw_lib::models::ConnectionForm; -use std::env; +use testcontainers::clients::Cli; + +fn scalar_to_i64(value: &serde_json::Value) -> i64 { + if let Some(v) = value.as_i64() { + return v; + } + value + .as_str() + .and_then(|v| v.parse::().ok()) + .expect("scalar should be i64 or parseable string") +} #[tokio::test] #[ignore] async fn test_mssql_integration_flow() { - let host = env::var("MSSQL_HOST").unwrap_or_else(|_| "localhost".to_string()); - let port = env::var("MSSQL_PORT") - .unwrap_or_else(|_| "1433".to_string()) - .parse() - .expect("MSSQL_PORT should be a number"); - let username = env::var("MSSQL_USER").unwrap_or_else(|_| "sa".to_string()); - let password = env::var("MSSQL_PASSWORD").unwrap_or_default(); - let database = env::var("MSSQL_DB").unwrap_or_else(|_| "master".to_string()); - - let form = ConnectionForm { - driver: "mssql".to_string(), - host: Some(host), - port: Some(port), - username: Some(username), - password: Some(password), - database: Some(database.clone()), - ..Default::default() - }; - - let driver = MssqlDriver::connect(&form) - .await - .expect("Failed to connect to SQL Server"); + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = 
mssql_context::mssql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MSSQL_DB or container default database should be present"); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; driver .test_connection() @@ -39,147 +37,1107 @@ async fn test_mssql_integration_flow() { .await .expect("list_databases failed"); assert!(!databases.is_empty(), "list_databases returned empty"); + assert!( + databases.iter().any(|db| db == &database), + "list_databases should include {}", + database + ); - let _tables = driver.list_tables(None).await.expect("list_tables failed"); + let table_name = "dbpaw_mssql_type_probe"; + let qualified = format!("[dbo].[{}]", table_name); + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + + driver + .execute_query(format!( + "CREATE TABLE {} (\ + id INT PRIMARY KEY, \ + name NVARCHAR(50), \ + amount DECIMAL(10,2), \ + payload VARBINARY(16), \ + created_at DATETIME2\ + )", + qualified + )) + .await + .expect("create table failed"); + + driver + .execute_query(format!( + "INSERT INTO {} (id, name, amount, payload, created_at) \ + VALUES (1, N'hello', 12.34, 0xDEADBEEF, '2026-01-02T03:04:05')", + qualified + )) + .await + .expect("insert failed"); + + let tables = driver.list_tables(None).await.expect("list_tables failed"); + assert!( + tables + .iter() + .any(|t| t.schema == "dbo" && t.name == table_name), + "list_tables should include dbo.{}", + table_name + ); + + let metadata = driver + .get_table_metadata("dbo".to_string(), table_name.to_string()) + .await + .expect("get_table_metadata failed"); + assert!( + metadata + .columns + .iter() + .any(|c| c.name == "id" && c.primary_key), + "metadata should include primary key id" + ); + assert!( + metadata.columns.iter().any(|c| c.name == "payload"), + "metadata should include payload column" + ); + + let ddl = 
driver + .get_table_ddl("dbo".to_string(), table_name.to_string()) + .await + .expect("get_table_ddl failed"); + assert!( + ddl.to_uppercase().contains("CREATE TABLE"), + "DDL should contain CREATE TABLE" + ); let result = driver - .execute_query("SELECT TOP 1 name FROM sys.databases".to_string()) + .execute_query(format!( + "SELECT id, name, amount, created_at FROM {} WHERE id = 1", + qualified + )) .await - .expect("execute_query failed"); - assert!(result.row_count >= 1); + .expect("select typed row failed"); + assert_eq!(result.row_count, 1); + let row = result + .data + .first() + .expect("typed result should include at least one row"); + let id_value = row.get("id").expect("id should exist"); + assert!( + id_value == &serde_json::Value::String("1".to_string()) + || id_value == &serde_json::Value::Number(serde_json::Number::from(1)), + "unexpected id value: {:?}", + id_value + ); + assert_eq!(row["name"], serde_json::Value::String("hello".to_string())); + assert!(row.get("amount").is_some(), "amount should exist"); + assert!(row.get("created_at").is_some(), "created_at should exist"); + + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver.close().await; +} +#[tokio::test] +#[ignore] +async fn test_mssql_get_table_data_supports_pagination_sort_filter_and_order_by() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + + let table_name = "dbpaw_mssql_grid_probe"; + let qualified = format!("[dbo].[{}]", table_name); + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; driver - .execute_query( - "IF OBJECT_ID('dbo.dbpaw_type_probe', 'U') IS NOT NULL DROP TABLE 
dbo.dbpaw_type_probe;" - .to_string(), + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name NVARCHAR(30), score INT)", + qualified + )) + .await + .expect("create dbpaw_mssql_grid_probe failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name, score) VALUES \ + (1, N'alpha', 10), (2, N'beta', 20), (3, N'gamma', 30), (4, N'delta', 40)", + qualified + )) + .await + .expect("insert dbpaw_mssql_grid_probe failed"); + + let page1 = driver + .get_table_data( + "dbo".to_string(), + table_name.to_string(), + 1, + 2, + Some("score".to_string()), + Some("desc".to_string()), + None, + None, ) .await - .expect("drop table failed"); + .expect("get_table_data page1 failed"); + assert_eq!(page1.total, 4); + assert_eq!(page1.data.len(), 2); + assert_eq!( + page1.data[0]["name"], + serde_json::Value::String("delta".to_string()) + ); - driver - .execute_query( - "CREATE TABLE dbo.dbpaw_type_probe (id INT PRIMARY KEY, flag BIT, amount DECIMAL(10,2), created_at DATETIME2);" - .to_string(), + let filtered = driver + .get_table_data( + "dbo".to_string(), + table_name.to_string(), + 1, + 10, + None, + None, + Some("score >= 20".to_string()), + None, ) .await - .expect("create table failed"); + .expect("get_table_data with filter failed"); + assert_eq!(filtered.total, 3); - driver - .execute_query( - "INSERT INTO dbo.dbpaw_type_probe (id, flag, amount, created_at) VALUES (1, 1, 12.34, '2026-01-02T03:04:05');" - .to_string(), + let ordered = driver + .get_table_data( + "dbo".to_string(), + table_name.to_string(), + 1, + 1, + Some("id".to_string()), + Some("asc".to_string()), + None, + Some("name DESC".to_string()), ) .await - .expect("insert failed"); + .expect("get_table_data with order_by failed"); + assert_eq!(ordered.total, 4); + assert_eq!(ordered.data.len(), 1); + assert_eq!( + ordered.data[0]["name"], + serde_json::Value::String("gamma".to_string()) + ); + + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP 
TABLE {};", + table_name, qualified + )) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mssql_get_table_data_rejects_invalid_sort_column() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + + let table_name = "dbpaw_mssql_invalid_sort_probe"; + let qualified = format!("[dbo].[{}]", table_name); + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver + .execute_query(format!("CREATE TABLE {} (id INT PRIMARY KEY)", qualified)) + .await + .expect("create dbpaw_mssql_invalid_sort_probe failed"); - let typed_result = driver - .execute_query( - "SELECT TOP 1 flag, amount, created_at FROM dbo.dbpaw_type_probe ORDER BY id DESC" - .to_string(), + let result = driver + .get_table_data( + "dbo".to_string(), + table_name.to_string(), + 1, + 10, + Some("id desc".to_string()), + Some("desc".to_string()), + None, + None, ) + .await; + let err = result.expect_err("invalid sort column should return error"); + assert!( + err.contains("[VALIDATION_ERROR] Invalid sort column name") + || err.contains("Invalid column name"), + "unexpected error: {}", + err + ); + + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mssql_table_structure_and_schema_overview() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + + let table_name = "dbpaw_mssql_overview_probe"; + let 
qualified = format!("[dbo].[{}]", table_name); + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, label NVARCHAR(50) NOT NULL)", + qualified + )) .await - .expect("select typed row failed"); + .expect("create dbpaw_mssql_overview_probe failed"); - let row = typed_result - .data - .first() - .expect("typed_result should include at least one row"); - assert_eq!(row["flag"], serde_json::Value::Bool(true)); + let structure = driver + .get_table_structure("dbo".to_string(), table_name.to_string()) + .await + .expect("get_table_structure failed"); + assert!( + structure.columns.iter().any(|c| c.name == "id"), + "table structure should include id" + ); + assert!( + structure.columns.iter().any(|c| c.name == "label"), + "table structure should include label" + ); + + let overview = driver + .get_schema_overview(Some("dbo".to_string())) + .await + .expect("get_schema_overview failed"); + assert!( + overview + .tables + .iter() + .any(|t| t.schema == "dbo" && t.name == table_name), + "schema overview should include dbo.{}", + table_name + ); + + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mssql_metadata_includes_indexes_and_foreign_keys() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + + let parent = "dbpaw_mssql_parent_meta_probe"; + let child = "dbpaw_mssql_child_meta_probe"; + let parent_qualified = format!("[dbo].[{}]", parent); + let child_qualified = format!("[dbo].[{}]", child); + + let _ = driver + 
.execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + child, child_qualified + )) + .await; + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + parent, parent_qualified + )) + .await; + + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY)", + parent_qualified + )) + .await + .expect("create parent table failed"); + driver + .execute_query(format!( + "CREATE TABLE {} (\ + id INT PRIMARY KEY, \ + parent_id INT NOT NULL, \ + name NVARCHAR(30), \ + CONSTRAINT fk_mssql_child_parent FOREIGN KEY (parent_id) REFERENCES {}(id)\ + )", + child_qualified, parent_qualified + )) + .await + .expect("create child table with fk failed"); + driver + .execute_query(format!( + "CREATE INDEX idx_mssql_child_name ON {} (name)", + child_qualified + )) + .await + .expect("create index failed"); + + let metadata = driver + .get_table_metadata("dbo".to_string(), child.to_string()) + .await + .expect("get_table_metadata failed"); + assert!( + metadata + .indexes + .iter() + .any(|i| i.name == "idx_mssql_child_name" && i.columns.contains(&"name".to_string())), + "metadata should include idx_mssql_child_name" + ); + assert!( + metadata + .foreign_keys + .iter() + .any(|fk| fk.column == "parent_id" && fk.referenced_table == parent), + "metadata should include FK parent_id -> {}(id)", + parent + ); + + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + child, child_qualified + )) + .await; + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + parent, parent_qualified + )) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mssql_boolean_and_json_type_mapping_regression() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: 
MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + + let table_name = "dbpaw_mssql_bool_json_probe"; + let qualified = format!("[dbo].[{}]", table_name); + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, flag BIT, meta NVARCHAR(MAX))", + qualified + )) + .await + .expect("create bool/json probe table failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, flag, meta) VALUES (1, 1, N'{{\"tier\":\"gold\"}}')", + qualified + )) + .await + .expect("insert bool/json probe row failed"); + + let query_result = driver + .execute_query(format!( + "SELECT flag, JSON_VALUE(meta, '$.tier') AS tier FROM {} WHERE id = 1", + qualified + )) + .await + .expect("select bool/json row failed"); + assert_eq!(query_result.row_count, 1); + let query_row = query_result.data.first().expect("query row should exist"); + let query_flag = query_row + .get("flag") + .expect("flag should exist in query result"); + assert!( + query_flag == &serde_json::Value::Bool(true) + || query_flag == &serde_json::Value::Number(serde_json::Number::from(1)), + "unexpected query flag value: {:?}", + query_flag + ); assert_eq!( - row["amount"], - serde_json::Value::String("12.34".to_string()) + query_row["tier"], + serde_json::Value::String("gold".to_string()) + ); + + let table_data = driver + .get_table_data( + "dbo".to_string(), + table_name.to_string(), + 1, + 10, + None, + None, + None, + None, + ) + .await + .expect("get_table_data for bool/json probe failed"); + assert_eq!(table_data.total, 1); + let grid_row = table_data.data.first().expect("table row should exist"); + let grid_flag = grid_row + .get("flag") + .expect("flag should exist in table_data result"); + assert!( + grid_flag == &serde_json::Value::Bool(true) + || grid_flag == 
&serde_json::Value::Number(serde_json::Number::from(1)), + "unexpected grid flag value: {:?}", + grid_flag ); assert!( - row["created_at"] - .as_str() - .map(|s| !s.is_empty()) - .unwrap_or(false), - "created_at should be rendered as a non-empty string" + grid_row.get("meta").is_some(), + "meta should exist in table_data" ); + + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver.close().await; } #[tokio::test] #[ignore] -async fn test_mssql_metadata_and_ddl_with_special_table_name() { - let host = env::var("MSSQL_HOST").unwrap_or_else(|_| "localhost".to_string()); - let port = env::var("MSSQL_PORT") - .unwrap_or_else(|_| "1433".to_string()) - .parse() - .expect("MSSQL_PORT should be a number"); - let username = env::var("MSSQL_USER").unwrap_or_else(|_| "sa".to_string()); - let password = env::var("MSSQL_PASSWORD").unwrap_or_default(); - let database = env::var("MSSQL_DB").unwrap_or_else(|_| "master".to_string()); +async fn test_mssql_execute_query_reports_affected_rows_for_update_delete() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; - let form = ConnectionForm { + let table_name = "dbpaw_mssql_affected_rows_probe"; + let qualified = format!("[dbo].[{}]", table_name); + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name NVARCHAR(50))", + qualified + )) + .await + .expect("create affected_rows probe table failed"); + + let inserted = driver + .execute_query(format!( + "INSERT INTO {} (id, name) VALUES (1, N'a'), (2, N'b')", + qualified + )) + .await + .expect("insert 
affected_rows probe rows failed"); + assert_eq!(inserted.row_count, 2); + + let updated = driver + .execute_query(format!( + "UPDATE {} SET name = N'bb' WHERE id = 2", + qualified + )) + .await + .expect("update affected_rows probe row failed"); + assert_eq!(updated.row_count, 1); + + let deleted = driver + .execute_query(format!("DELETE FROM {} WHERE id IN (1, 2)", qualified)) + .await + .expect("delete affected_rows probe rows failed"); + assert_eq!(deleted.row_count, 2); + + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mssql_transaction_commit_and_rollback() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + + let table_name = "dbpaw_mssql_txn_probe"; + let qualified = format!("[dbo].[{}]", table_name); + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name NVARCHAR(50))", + qualified + )) + .await + .expect("create mssql txn probe table failed"); + + // Test rollback using a single connection from the pool + { + let mut conn = driver.pool.get().await.expect("get connection failed"); + conn.simple_query("BEGIN TRANSACTION") + .await + .expect("begin transaction failed"); + conn.simple_query(&format!( + "INSERT INTO {} (id, name) VALUES (1, N'rolled_back')", + qualified + )) + .await + .expect("insert in rollback tx failed"); + conn.simple_query("ROLLBACK TRANSACTION") + .await + .expect("rollback failed"); + } + + let rolled_back = driver + .execute_query(format!( + "SELECT COUNT(*) AS c FROM {} WHERE id = 1", + 
qualified + )) + .await + .expect("count after rollback failed"); + let rolled_back_count = scalar_to_i64(&rolled_back.data[0]["c"]); + assert_eq!(rolled_back_count, 0); + + // Test commit using a single connection from the pool + { + let mut conn = driver.pool.get().await.expect("get connection failed"); + conn.simple_query("BEGIN TRANSACTION") + .await + .expect("begin transaction failed"); + conn.simple_query(&format!( + "INSERT INTO {} (id, name) VALUES (2, N'committed')", + qualified + )) + .await + .expect("insert in commit tx failed"); + conn.simple_query("COMMIT TRANSACTION") + .await + .expect("commit failed"); + } + + let committed = driver + .execute_query(format!( + "SELECT COUNT(*) AS c FROM {} WHERE id = 2", + qualified + )) + .await + .expect("count after commit failed"); + let committed_count = scalar_to_i64(&committed.data[0]["c"]); + assert_eq!(committed_count, 1); + + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mssql_error_handling_for_sql_error() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + + let err = driver + .execute_query("SELECT * FROM __dbpaw_table_not_exists".to_string()) + .await + .expect_err("invalid SQL should return query error"); + assert!( + err.contains("[QUERY_ERROR]") || err.contains("Invalid object name"), + "unexpected error shape: {}", + err + ); +} + +#[tokio::test] +#[ignore] +async fn test_mssql_connection_failure_with_wrong_password() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, mut form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + form.password = 
Some("dbpaw_wrong_password".to_string()); + + let err = match MssqlDriver::connect(&form).await { + Ok(_) => panic!("wrong password should fail"), + Err(err) => err, + }; + assert!( + err.starts_with("[CONN_FAILED]"), + "unexpected error: {}", + err + ); + assert!(!err.trim().is_empty(), "error message should not be empty"); +} + +#[tokio::test] +#[ignore] +async fn test_mssql_connection_timeout_or_unreachable_host_error() { + let form = dbpaw_lib::models::ConnectionForm { driver: "mssql".to_string(), - host: Some(host), - port: Some(port), - username: Some(username), - password: Some(password), - database: Some(database), + host: Some("203.0.113.1".to_string()), + port: Some(1433), + username: Some("sa".to_string()), + password: Some("Password123".to_string()), + database: Some("master".to_string()), + ssl: Some(false), ..Default::default() }; - let driver = MssqlDriver::connect(&form) + let err = match MssqlDriver::connect(&form).await { + Ok(_) => panic!("unreachable host should fail"), + Err(err) => err, + }; + assert!( + err.starts_with("[CONN_FAILED]"), + "unexpected error: {}", + err + ); + assert!( + err.to_ascii_lowercase().contains("timed out") + || err.to_ascii_lowercase().contains("timeout") + || err.to_ascii_lowercase().contains("network") + || err.to_ascii_lowercase().contains("connection refused") + || err.to_ascii_lowercase().contains("host") + || err.to_ascii_lowercase().contains("unreachable"), + "unexpected timeout/unreachable error: {}", + err + ); +} + +#[tokio::test] +#[ignore] +async fn test_mssql_batch_insert_and_batch_execute_flow() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + + let table_name = "dbpaw_mssql_batch_probe"; + let qualified = format!("[dbo].[{}]", table_name); + let _ = driver + .execute_query(format!( 
+ "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, category NVARCHAR(20), score INT)", + qualified + )) + .await + .expect("create batch probe table failed"); + + let value_rows: Vec = (1..=50) + .map(|id| { + let category = if id <= 25 { "alpha" } else { "beta" }; + format!("({}, N'{}', {})", id, category, id) + }) + .collect(); + let insert_sql = format!( + "INSERT INTO {} (id, category, score) VALUES {}", + qualified, + value_rows.join(", ") + ); + let inserted = driver + .execute_query(insert_sql) .await - .expect("Failed to connect to SQL Server"); + .expect("batch insert failed"); + assert_eq!(inserted.row_count, 50); - let schema = "dbo"; - let table_name = "dbpaw type-probe"; - let qualified = format!("[{}].[{}]", schema, table_name); + let batch_sqls = vec![ + format!( + "UPDATE {} SET score = score + 100 WHERE id <= 10", + qualified + ), + format!( + "UPDATE {} SET category = N'gamma' WHERE id BETWEEN 30 AND 40", + qualified + ), + format!("DELETE FROM {} WHERE id IN (3, 6, 9, 12, 15)", qualified), + ]; + let mut affected = Vec::new(); + for sql in batch_sqls { + let result = driver + .execute_query(sql) + .await + .expect("batch execute statement failed"); + affected.push(result.row_count); + } + assert_eq!(affected, vec![10, 11, 5]); + + let check_total = driver + .execute_query(format!("SELECT COUNT(*) AS c FROM {}", qualified)) + .await + .expect("count after batch execute failed"); + let total = scalar_to_i64(&check_total.data[0]["c"]); + assert_eq!(total, 45); + + let check_gamma = driver + .execute_query(format!( + "SELECT COUNT(*) AS c FROM {} WHERE category = N'gamma'", + qualified + )) + .await + .expect("count gamma rows failed"); + let gamma = scalar_to_i64(&check_gamma.data[0]["c"]); + assert_eq!(gamma, 11); let _ = driver .execute_query(format!( - "IF OBJECT_ID(N'{}.{}', N'U') IS NOT NULL DROP TABLE {};", - schema, 
table_name, qualified + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified )) .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mssql_large_text_and_blob_round_trip() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + let table_name = "dbpaw_mssql_large_field_probe"; + let qualified = format!("[dbo].[{}]", table_name); + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; driver .execute_query(format!( - "CREATE TABLE {} (id INT PRIMARY KEY, note NVARCHAR(50));", + "CREATE TABLE {} (id INT PRIMARY KEY, body NVARCHAR(MAX), payload VARBINARY(MAX))", qualified )) .await - .expect("create special-name table failed"); + .expect("create large field probe table failed"); - let tables = driver.list_tables(None).await.expect("list_tables failed"); - assert!( - tables - .iter() - .any(|t| t.schema == schema && t.name == table_name), - "list_tables should include special-name table" - ); + let large_text = "x".repeat(70000); + let blob_data: Vec = (0..4096).map(|i| (i % 256) as u8).collect(); - let metadata = driver - .get_table_metadata(schema.to_string(), table_name.to_string()) + driver + .execute_query(format!( + "INSERT INTO {} (id, body, payload) VALUES (1, N'{}', 0x{})", + qualified, + large_text, + blob_data + .iter() + .map(|b| format!("{:02x}", b)) + .collect::() + )) .await - .expect("get_table_metadata failed"); + .expect("insert large field probe row failed"); + + let result = driver + .execute_query(format!( + "SELECT body, payload FROM {} WHERE id = 1", + qualified + )) + .await + .expect("select large field probe row failed"); + assert_eq!(result.row_count, 1); + let row = 
result.data.first().expect("large field row should exist"); + let body = row + .get("body") + .and_then(|v| v.as_str()) + .expect("body should be string"); + assert_eq!(body.len(), 70000); + assert!(row.get("payload").is_some(), "payload should exist"); + + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mssql_concurrent_connections_can_query() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + + let table_name = "dbpaw_mssql_concurrent_probe"; + let qualified = format!("[dbo].[{}]", table_name); + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT, value NVARCHAR(50))", + qualified + )) + .await + .expect("create concurrent probe table failed"); + driver + .execute_query(format!("INSERT INTO {} VALUES (1, N'test')", qualified)) + .await + .expect("insert concurrent probe row failed"); + driver.close().await; + + let mut handles = Vec::new(); + + for _ in 0..8 { + let task_form = form.clone(); + handles.push(tokio::spawn(async move { + let task_driver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&task_form)).await; + let result = task_driver + .execute_query("SELECT 1 AS ok".to_string()) + .await; + task_driver.close().await; + result + })); + } + + for handle in handles { + let result = handle.await.expect("concurrent mssql task panicked"); + let data = result.expect("concurrent mssql query failed"); + assert_eq!(data.row_count, 1); + let ok = &data.data[0]["ok"]; + let matches = ok == "1" || *ok == 
serde_json::Value::Number(1.into()); + assert!(matches, "ok should be 1, got {}", ok); + } + + let cleanup_driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + let _ = cleanup_driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + cleanup_driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mssql_view_can_be_listed_and_queried() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + + let base_table = "dbpaw_mssql_view_base_probe"; + let view_name = "dbpaw_mssql_view_probe_v"; + let qualified_table = format!("[dbo].[{}]", base_table); + let qualified_view = format!("[dbo].[{}]", view_name); + + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'V') IS NOT NULL DROP VIEW {};", + view_name, qualified_view + )) + .await; + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + base_table, qualified_table + )) + .await; + + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name NVARCHAR(50), score INT)", + qualified_table + )) + .await + .expect("create base table for view failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name, score) VALUES (1, N'alice', 10), (2, N'bob', 20)", + qualified_table + )) + .await + .expect("insert base rows for view failed"); + driver + .execute_query(format!( + "EXEC(N'CREATE VIEW {} AS SELECT id, name FROM {} WHERE score >= 20')", + qualified_view, qualified_table + )) + .await + .expect("create view failed"); + + let tables = driver + .list_tables(Some("dbo".to_string())) + .await + .expect("list_tables failed"); assert!( - metadata - .columns + tables .iter() - .any(|c| 
c.name == "id" && c.primary_key), - "metadata should include primary key id" + .any(|t| t.name == base_table && t.r#type == "table"), + "list_tables should include base table" ); assert!( - metadata.columns.iter().any(|c| c.name == "note"), - "metadata should include note column" + tables + .iter() + .any(|t| t.name == view_name && t.r#type == "view"), + "list_tables should include view with type=view" ); - let ddl = driver - .get_table_ddl(schema.to_string(), table_name.to_string()) + let view_rows = driver + .execute_query(format!( + "SELECT id, name FROM {} ORDER BY id", + qualified_view + )) .await - .expect("get_table_ddl failed"); - assert!( - ddl.to_uppercase().contains("CREATE TABLE"), - "DDL should contain CREATE TABLE" + .expect("select from view failed"); + assert_eq!(view_rows.row_count, 1); + let row = view_rows.data.first().expect("view row should exist"); + let id_matches = row["id"] == serde_json::Value::Number(2.into()) + || row["id"] == serde_json::Value::String("2".to_string()); + assert!(id_matches, "unexpected id payload: {}", row["id"]); + assert_eq!(row["name"], serde_json::Value::String("bob".to_string())); + + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'V') IS NOT NULL DROP VIEW {};", + view_name, qualified_view + )) + .await; + let _ = driver + .execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + base_table, qualified_table + )) + .await; + driver.close().await; +} + +#[tokio::test] +#[ignore] +async fn test_mssql_prepared_statements_prepare_execute_and_deallocate() { + let docker = (!mssql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mssql_context::mssql_form_from_test_context(docker.as_ref()); + let driver: MssqlDriver = + mssql_context::connect_with_retry(|| MssqlDriver::connect(&form)).await; + + let table_name = "dbpaw_mssql_prepared_stmt_probe"; + let qualified = format!("[dbo].[{}]", table_name); + let _ = driver + 
.execute_query(format!( + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified + )) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name NVARCHAR(50))", + qualified + )) + .await + .expect("create prepared stmt probe table failed"); + + let prepared_insert_sql = format!("INSERT INTO {} (id, name) VALUES (@P1, @P2)", qualified); + let inserted_a = driver + .execute_query(format!( + "EXEC sp_executesql N'{}', N'@P1 INT, @P2 NVARCHAR(50)', @P1 = 1, @P2 = N'alice'", + prepared_insert_sql.replace("'", "''") + )) + .await + .expect("prepared insert alice failed"); + assert_eq!(inserted_a.row_count, 1); + + let inserted_b = driver + .execute_query(format!( + "EXEC sp_executesql N'{}', N'@P1 INT, @P2 NVARCHAR(50)', @P1 = 2, @P2 = N'bob'", + prepared_insert_sql.replace("'", "''") + )) + .await + .expect("prepared insert bob failed"); + assert_eq!(inserted_b.row_count, 1); + + let prepared_update_sql = format!("UPDATE {} SET name = @P1 WHERE id = @P2", qualified); + let updated = driver + .execute_query(format!( + "EXEC sp_executesql N'{}', N'@P1 NVARCHAR(50), @P2 INT', @P1 = N'alice-updated', @P2 = 1", + prepared_update_sql.replace("'", "''") + )) + .await + .expect("prepared update failed"); + assert_eq!(updated.row_count, 1); + + let prepared_select_sql = format!("SELECT name FROM {} WHERE id = @P1", qualified); + let selected_exec = driver + .execute_query(format!( + "EXEC sp_executesql N'{}', N'@P1 INT', @P1 = 1", + prepared_select_sql.replace("'", "''") + )) + .await + .expect("prepared select failed"); + assert_eq!(selected_exec.row_count, 1); + let selected = driver + .execute_query(format!("SELECT name FROM {} WHERE id = 1", qualified)) + .await + .expect("verify prepared select result failed"); + assert_eq!(selected.row_count, 1); + assert_eq!( + selected.data[0]["name"], + serde_json::Value::String("alice-updated".to_string()) ); + let verify = driver + .execute_query(format!("SELECT COUNT(*) AS c 
FROM {}", qualified)) + .await + .expect("verify prepared writes failed"); + let total = scalar_to_i64(&verify.data[0]["c"]); + assert_eq!(total, 2); + let _ = driver .execute_query(format!( - "IF OBJECT_ID(N'{}.{}', N'U') IS NOT NULL DROP TABLE {};", - schema, table_name, qualified + "IF OBJECT_ID(N'dbo.{}', N'U') IS NOT NULL DROP TABLE {};", + table_name, qualified )) .await; + driver.close().await; } diff --git a/src-tauri/tests/mysql_command_integration.rs b/src-tauri/tests/mysql_command_integration.rs new file mode 100644 index 0000000..a820a1f --- /dev/null +++ b/src-tauri/tests/mysql_command_integration.rs @@ -0,0 +1,326 @@ +#[path = "common/mysql_context.rs"] +mod mysql_context; + +use dbpaw_lib::commands::{connection, metadata, query}; +use dbpaw_lib::db::drivers::mysql::MysqlDriver; +use dbpaw_lib::db::drivers::DatabaseDriver; +use dbpaw_lib::models::ConnectionForm; +use std::time::{SystemTime, UNIX_EPOCH}; +use testcontainers::clients::Cli; +use tokio::time::{sleep, Duration}; + +fn unique_table_name(prefix: &str) -> String { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after unix epoch") + .as_millis(); + format!("{}_{}", prefix, millis) +} + +async fn wait_until_mysql_ready(form: &ConnectionForm) { + let mut last_error = String::new(); + for _ in 0..45 { + let probe = form.clone(); + match connection::test_connection_ephemeral(probe).await { + Ok(_) => return, + Err(err) => { + last_error = err; + sleep(Duration::from_secs(1)).await; + } + } + } + panic!("mysql is not ready for command tests: {last_error}"); +} + +async fn prepare_query_test_table(form: &ConnectionForm, table: &str) { + let driver = MysqlDriver::connect(form) + .await + .expect("failed to connect mysql driver"); + + driver + .execute_query(format!("DROP TABLE IF EXISTS {}", table)) + .await + .expect("drop table should succeed"); + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name VARCHAR(64))", + table + )) + 
.await + .expect("create table should succeed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name) VALUES (1, 'DbPaw')", + table + )) + .await + .expect("insert row should succeed"); + driver.close().await; +} + +async fn cleanup_table(form: &ConnectionForm, table: &str) { + let driver = MysqlDriver::connect(form) + .await + .expect("failed to connect mysql driver for cleanup"); + driver + .execute_query(format!("DROP TABLE IF EXISTS {}", table)) + .await + .expect("drop table should succeed"); + driver.close().await; +} + +async fn execute_by_conn_sql( + form: ConnectionForm, + sql: String, +) -> Result { + query::execute_by_conn_direct(form, sql).await +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_test_connection_success() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + + let result = connection::test_connection_ephemeral(form) + .await + .expect("test_connection_ephemeral should succeed"); + + assert!(result.success); + assert!(result.latency_ms.is_some()); +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_test_connection_invalid_password_returns_error() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, mut form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let ready_form = form.clone(); + wait_until_mysql_ready(&ready_form).await; + form.password = Some("dbpaw_wrong_password".to_string()); + + let result = connection::test_connection_ephemeral(form).await; + + assert!(result.is_err()); + let error = result.err().unwrap_or_default(); + assert!(!error.trim().is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_list_tables_by_conn_contains_created_table() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = 
mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let table = unique_table_name("dbpaw_cmd_tables"); + prepare_query_test_table(&form, &table).await; + + let tables = metadata::list_tables_by_conn(form.clone()) + .await + .expect("list_tables_by_conn should succeed"); + + assert!(tables.iter().any(|t| t.name == table)); + cleanup_table(&form, &table).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_list_tables_by_conn_invalid_credentials_returns_error() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, mut form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let ready_form = form.clone(); + wait_until_mysql_ready(&ready_form).await; + form.password = Some("dbpaw_wrong_password".to_string()); + + let result = metadata::list_tables_by_conn(form).await; + + assert!(result.is_err()); + let error = result.err().unwrap_or_default(); + assert!(!error.trim().is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_list_databases_contains_target_db() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let target_db = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + + let databases = connection::list_databases(form) + .await + .expect("list_databases should succeed"); + + assert!(!databases.is_empty()); + assert!(databases.iter().any(|db| db == &target_db)); + assert!(databases.iter().all(|db| !db.trim().is_empty())); +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_list_databases_invalid_credentials_returns_error() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, mut form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let ready_form = 
form.clone(); + wait_until_mysql_ready(&ready_form).await; + form.password = Some("dbpaw_wrong_password".to_string()); + + let result = connection::list_databases(form).await; + + assert!(result.is_err()); + let error = result.err().unwrap_or_default(); + assert!(!error.trim().is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_execute_by_conn_select_returns_rows() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let table = unique_table_name("dbpaw_cmd_exec_select"); + prepare_query_test_table(&form, &table).await; + + let sql = format!("SELECT id, name FROM {} ORDER BY id", table); + let result = execute_by_conn_sql(form.clone(), sql) + .await + .expect("execute_by_conn should succeed"); + + assert!(result.success); + assert!(result.row_count >= 1); + assert!(!result.data.is_empty()); + let row = result.data.first().expect("result row should exist"); + let name = row.get("name").and_then(|v| v.as_str()); + assert_eq!(name, Some("DbPaw")); + cleanup_table(&form, &table).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_execute_by_conn_invalid_sql_returns_error() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + + let result = execute_by_conn_sql( + form, + "SELECT * FROM __dbpaw_missing_command_table".to_string(), + ) + .await; + + assert!(result.is_err()); + let error = result.err().unwrap_or_default(); + assert!(!error.trim().is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_execute_by_conn_insert_affects_rows() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = 
mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let table = unique_table_name("dbpaw_cmd_exec_insert"); + + let driver = MysqlDriver::connect(&form) + .await + .expect("failed to connect mysql driver"); + driver + .execute_query(format!("DROP TABLE IF EXISTS {}", table)) + .await + .expect("drop table should succeed"); + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name VARCHAR(64))", + table + )) + .await + .expect("create table should succeed"); + driver.close().await; + + let sql = format!("INSERT INTO {} (id, name) VALUES (1, 'alpha')", table); + let result = execute_by_conn_sql(form.clone(), sql) + .await + .expect("execute_by_conn insert should succeed"); + assert!(result.success); + assert_eq!(result.row_count, 1); + + cleanup_table(&form, &table).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_get_table_data_by_conn_pagination_works() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + + let database = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + let table = unique_table_name("dbpaw_cmd_page"); + + let driver = MysqlDriver::connect(&form) + .await + .expect("failed to connect mysql driver"); + driver + .execute_query(format!("DROP TABLE IF EXISTS {}", table)) + .await + .expect("drop table should succeed"); + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name VARCHAR(64))", + table + )) + .await + .expect("create table should succeed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name) VALUES (1, 'a'), (2, 'b'), (3, 'c')", + table + )) + .await + .expect("insert rows should succeed"); + driver.close().await; + + let page1 = query::get_table_data_by_conn(form.clone(), database.clone(), table.clone(), 1, 2) + .await + 
.expect("page 1 should succeed"); + let page2 = query::get_table_data_by_conn(form.clone(), database, table.clone(), 2, 2) + .await + .expect("page 2 should succeed"); + + assert_eq!(page1.total, 3); + assert_eq!(page1.limit, 2); + assert_eq!(page1.page, 1); + assert_eq!(page1.data.len(), 2); + assert_eq!(page2.data.len(), 1); + + cleanup_table(&form, &table).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_get_table_data_by_conn_invalid_pagination_returns_error() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + + let database = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + let table = unique_table_name("dbpaw_cmd_invalid_page"); + prepare_query_test_table(&form, &table).await; + + let result = query::get_table_data_by_conn(form.clone(), database, table.clone(), 0, 10).await; + assert!(result.is_err()); + let error = result.err().unwrap_or_default(); + assert!(error.contains("[VALIDATION_ERROR]")); + + cleanup_table(&form, &table).await; +} diff --git a/src-tauri/tests/mysql_integration.rs b/src-tauri/tests/mysql_integration.rs index e23c89c..8cdfff9 100644 --- a/src-tauri/tests/mysql_integration.rs +++ b/src-tauri/tests/mysql_integration.rs @@ -1,39 +1,25 @@ +#[path = "common/mysql_context.rs"] +mod mysql_context; + use dbpaw_lib::db::drivers::mysql::MysqlDriver; use dbpaw_lib::db::drivers::DatabaseDriver; -use dbpaw_lib::models::ConnectionForm; -use std::env; +use testcontainers::clients::Cli; #[tokio::test] #[ignore] async fn test_mysql_integration_flow() { - // Retrieve connection info from environment variables - // Defaults are set for a local MySQL instance often used in development - let host = env::var("MYSQL_HOST").unwrap_or_else(|_| "localhost".to_string()); - let port = env::var("MYSQL_PORT") - .unwrap_or_else(|_| "3306".to_string()) - 
.parse() - .unwrap(); - let username = env::var("MYSQL_USER").unwrap_or_else(|_| "root".to_string()); - let password = env::var("MYSQL_PASSWORD").unwrap_or_else(|_| "123456".to_string()); - // Use a specific test database if provided, otherwise default to None (which might fail list_tables if implementation depends on it) - // Looking at mysql.rs, list_tables uses self.form.database as default schema. - let database = env::var("MYSQL_DB").ok(); - - println!("Testing MySQL connection to {}:{}", host, port); - - let form = ConnectionForm { - driver: "mysql".to_string(), - host: Some(host), - port: Some(port), - username: Some(username), - password: Some(password), - database: database.clone(), - ..Default::default() - }; + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form.database.clone(); - let driver: MysqlDriver = MysqlDriver::connect(&form) - .await - .expect("Failed to connect"); + println!( + "Testing MySQL connection to {}:{}", + form.host.as_deref().unwrap_or("localhost"), + form.port.unwrap_or(3306) + ); + + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; // 1. 
Test Connection // This just runs "SELECT 1" @@ -110,28 +96,15 @@ async fn test_mysql_integration_flow() { #[tokio::test] #[ignore] async fn test_mysql_metadata_and_type_mapping_flow() { - let host = env::var("MYSQL_HOST").unwrap_or_else(|_| "localhost".to_string()); - let port = env::var("MYSQL_PORT") - .unwrap_or_else(|_| "3306".to_string()) - .parse() - .unwrap(); - let username = env::var("MYSQL_USER").unwrap_or_else(|_| "root".to_string()); - let password = env::var("MYSQL_PASSWORD").unwrap_or_else(|_| "123456".to_string()); - let database = env::var("MYSQL_DB").unwrap_or_else(|_| "test_db".to_string()); - - let form = ConnectionForm { - driver: "mysql".to_string(), - host: Some(host), - port: Some(port), - username: Some(username), - password: Some(password), - database: Some(database.clone()), - ..Default::default() - }; + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MYSQL_DB or container default database should be present"); - let driver: MysqlDriver = MysqlDriver::connect(&form) - .await - .expect("Failed to connect"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; let table_name = "dbpaw_type_probe"; let qualified = format!("`{}`.`{}`", database, table_name); @@ -271,28 +244,11 @@ async fn test_mysql_metadata_and_type_mapping_flow() { #[tokio::test] #[ignore] async fn test_mysql_list_databases_and_tables_with_binary_collation_database() { - let host = env::var("MYSQL_HOST").unwrap_or_else(|_| "localhost".to_string()); - let port = env::var("MYSQL_PORT") - .unwrap_or_else(|_| "3306".to_string()) - .parse() - .unwrap(); - let username = env::var("MYSQL_USER").unwrap_or_else(|_| "root".to_string()); - let password = env::var("MYSQL_PASSWORD").unwrap_or_else(|_| "123456".to_string()); - let database = 
env::var("MYSQL_DB").unwrap_or_else(|_| "test_db".to_string()); - - let form = ConnectionForm { - driver: "mysql".to_string(), - host: Some(host), - port: Some(port), - username: Some(username), - password: Some(password), - database: Some(database), - ..Default::default() - }; + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); - let driver: MysqlDriver = MysqlDriver::connect(&form) - .await - .expect("Failed to connect"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; let probe_db = "dbpaw_bin_probe"; let probe_table = "probe_tbl"; @@ -342,28 +298,15 @@ async fn test_mysql_list_databases_and_tables_with_binary_collation_database() { #[tokio::test] #[ignore] async fn test_mysql_list_tables_with_unicode_table_name() { - let host = env::var("MYSQL_HOST").unwrap_or_else(|_| "localhost".to_string()); - let port = env::var("MYSQL_PORT") - .unwrap_or_else(|_| "3306".to_string()) - .parse() - .unwrap(); - let username = env::var("MYSQL_USER").unwrap_or_else(|_| "root".to_string()); - let password = env::var("MYSQL_PASSWORD").unwrap_or_else(|_| "123456".to_string()); - let database = env::var("MYSQL_DB").unwrap_or_else(|_| "test_db".to_string()); - - let form = ConnectionForm { - driver: "mysql".to_string(), - host: Some(host), - port: Some(port), - username: Some(username), - password: Some(password), - database: Some(database.clone()), - ..Default::default() - }; + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MYSQL_DB or container default database should be present"); - let driver: MysqlDriver = MysqlDriver::connect(&form) - .await - .expect("Failed to connect"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| 
MysqlDriver::connect(&form)).await; let table_name = "dbpaw_中文_probe"; let qualified = format!("`{}`.`{}`", database, table_name); @@ -394,3 +337,890 @@ async fn test_mysql_list_tables_with_unicode_table_name() { .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) .await; } + +#[tokio::test] +#[ignore] +async fn test_mysql_get_table_data_supports_pagination_sort_filter_and_order_by() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MYSQL_DB or container default database should be present"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_grid_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name VARCHAR(20), score INT)", + qualified + )) + .await + .expect("create grid probe table failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name, score) VALUES \ + (1, 'alpha', 10), (2, 'beta', 20), (3, 'gamma', 30), (4, 'delta', 40)", + qualified + )) + .await + .expect("insert grid probe rows failed"); + + let page1 = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 2, + Some("score".to_string()), + Some("desc".to_string()), + None, + None, + ) + .await + .expect("get_table_data for page1 failed"); + assert_eq!(page1.total, 4); + assert_eq!(page1.data.len(), 2); + assert_eq!( + page1.data[0]["name"], + serde_json::Value::String("delta".to_string()) + ); + + let filtered = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 10, + None, + None, + Some("score >= 20".to_string()), + None, + ) + .await + .expect("get_table_data with filter failed"); + 
assert_eq!(filtered.total, 3); + + let ordered = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 1, + Some("id".to_string()), + Some("asc".to_string()), + None, + Some("name DESC".to_string()), + ) + .await + .expect("get_table_data with order_by priority failed"); + assert_eq!(ordered.total, 4); + assert_eq!(ordered.data.len(), 1); + assert_eq!( + ordered.data[0]["name"], + serde_json::Value::String("gamma".to_string()) + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_get_table_data_rejects_invalid_sort_column() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MYSQL_DB or container default database should be present"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_invalid_sort_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name VARCHAR(20))", + qualified + )) + .await + .expect("create invalid sort probe table failed"); + + let result = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 10, + Some("id desc".to_string()), + Some("desc".to_string()), + None, + None, + ) + .await; + let err = result.expect_err("invalid sort column should return an error"); + assert!( + err.contains("[VALIDATION_ERROR] Invalid sort column name"), + "unexpected error: {}", + err + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_table_structure_and_schema_overview() { + let docker = 
(!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MYSQL_DB or container default database should be present"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_overview_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, label VARCHAR(30) NOT NULL)", + qualified + )) + .await + .expect("create overview probe table failed"); + + let structure = driver + .get_table_structure(database.clone(), table_name.to_string()) + .await + .expect("get_table_structure failed"); + assert!( + structure + .columns + .iter() + .any(|c| c.name == "id" && c.primary_key), + "table structure should include primary key id" + ); + assert!( + structure.columns.iter().any(|c| c.name == "label"), + "table structure should include label column" + ); + + let overview = driver + .get_schema_overview(Some(database.clone())) + .await + .expect("get_schema_overview failed"); + assert!( + overview + .tables + .iter() + .any(|t| t.schema == database && t.name == table_name), + "schema overview should include {}.{}", + database, + table_name + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_metadata_includes_indexes_and_foreign_keys() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MYSQL_DB or container default database should be present"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| 
MysqlDriver::connect(&form)).await; + + let parent = "dbpaw_parent_meta_probe"; + let child = "dbpaw_child_meta_probe"; + let parent_qualified = format!("`{}`.`{}`", database, parent); + let child_qualified = format!("`{}`.`{}`", database, child); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", child_qualified)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", parent_qualified)) + .await; + + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY)", + parent_qualified + )) + .await + .expect("create parent table failed"); + driver + .execute_query(format!( + "CREATE TABLE {} (\ + id INT PRIMARY KEY, \ + parent_id INT NOT NULL, \ + name VARCHAR(30), \ + INDEX idx_child_name (name), \ + CONSTRAINT fk_child_parent FOREIGN KEY (parent_id) REFERENCES {}(id)\ + )", + child_qualified, parent_qualified + )) + .await + .expect("create child table with fk/index failed"); + + let metadata = driver + .get_table_metadata(database.clone(), child.to_string()) + .await + .expect("get_table_metadata for child failed"); + assert!( + metadata + .indexes + .iter() + .any(|i| i.name == "idx_child_name" && i.columns.contains(&"name".to_string())), + "metadata should include idx_child_name index" + ); + assert!( + metadata.foreign_keys.iter().any(|fk| { + fk.name == "fk_child_parent" + && fk.column == "parent_id" + && fk.referenced_table == parent + && fk.referenced_column == "id" + }), + "metadata should include fk_child_parent" + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", child_qualified)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", parent_qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_boolean_and_json_type_mapping_regression() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + 
.database + .clone() + .expect("MYSQL_DB or container default database should be present"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_bool_json_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, flag BOOLEAN, meta JSON)", + qualified + )) + .await + .expect("create bool/json probe table failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, flag, meta) VALUES (1, 1, '{{\"tier\": \"gold\"}}')", + qualified + )) + .await + .expect("insert bool/json probe row failed"); + + let query_result = driver + .execute_query(format!("SELECT flag, meta FROM {} WHERE id = 1", qualified)) + .await + .expect("select bool/json row failed"); + assert_eq!(query_result.row_count, 1); + let query_row = query_result.data.first().expect("query row should exist"); + assert!(query_row.get("flag").is_some(), "flag should exist"); + assert!(query_row.get("meta").is_some(), "meta should exist"); + + let table_data = driver + .get_table_data( + database.clone(), + table_name.to_string(), + 1, + 10, + None, + None, + None, + None, + ) + .await + .expect("get_table_data for bool/json table failed"); + assert_eq!(table_data.total, 1); + let grid_row = table_data + .data + .first() + .expect("table data row should exist"); + assert!( + grid_row.get("flag").is_some(), + "flag should exist in table_data" + ); + assert!( + grid_row.get("meta").is_some(), + "meta should exist in table_data" + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_transaction_commit_and_rollback() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = 
mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MYSQL_DB or container default database should be present"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_txn_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name VARCHAR(30))", + qualified + )) + .await + .expect("create txn probe table failed"); + + let mut rollback_tx = driver.pool.begin().await.expect("begin rollback tx failed"); + sqlx::query(&format!( + "INSERT INTO {} (id, name) VALUES (1, 'rolled_back')", + qualified + )) + .execute(&mut *rollback_tx) + .await + .expect("insert in rollback tx failed"); + rollback_tx.rollback().await.expect("rollback tx failed"); + + let rolled_back = driver + .execute_query(format!( + "SELECT COUNT(*) AS c FROM {} WHERE id = 1", + qualified + )) + .await + .expect("count after rollback failed"); + assert_eq!(rolled_back.row_count, 1); + let rolled_back_count = rolled_back.data[0]["c"] + .as_str() + .expect("rollback count should be string"); + assert_eq!(rolled_back_count, "0"); + + let mut commit_tx = driver.pool.begin().await.expect("begin commit tx failed"); + sqlx::query(&format!( + "INSERT INTO {} (id, name) VALUES (2, 'committed')", + qualified + )) + .execute(&mut *commit_tx) + .await + .expect("insert in commit tx failed"); + commit_tx.commit().await.expect("commit tx failed"); + + let committed = driver + .execute_query(format!( + "SELECT COUNT(*) AS c FROM {} WHERE id = 2", + qualified + )) + .await + .expect("count after commit failed"); + assert_eq!(committed.row_count, 1); + let committed_count = committed.data[0]["c"] + .as_str() + .expect("commit count should be string"); + assert_eq!(committed_count, "1"); + + let _ = driver 
+ .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_execute_query_reports_affected_rows_for_update_delete() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MYSQL_DB or container default database should be present"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_affected_rows_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name VARCHAR(30))", + qualified + )) + .await + .expect("create affected_rows probe table failed"); + + let inserted = driver + .execute_query(format!( + "INSERT INTO {} (id, name) VALUES (1, 'a'), (2, 'b')", + qualified + )) + .await + .expect("insert affected_rows probe rows failed"); + assert_eq!(inserted.row_count, 2); + + let updated = driver + .execute_query(format!("UPDATE {} SET name = 'bb' WHERE id = 2", qualified)) + .await + .expect("update affected_rows probe row failed"); + assert_eq!(updated.row_count, 1); + + let deleted = driver + .execute_query(format!("DELETE FROM {} WHERE id IN (1, 2)", qualified)) + .await + .expect("delete affected_rows probe rows failed"); + assert_eq!(deleted.row_count, 2); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_large_text_and_blob_round_trip() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MYSQL_DB or container 
default database should be present"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_large_field_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, body LONGTEXT, payload LONGBLOB)", + qualified + )) + .await + .expect("create large field probe table failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, body, payload) VALUES (1, REPEAT('x', 70000), UNHEX(REPEAT('AB', 2048)))", + qualified + )) + .await + .expect("insert large field probe row failed"); + + let result = driver + .execute_query(format!( + "SELECT body, payload FROM {} WHERE id = 1", + qualified + )) + .await + .expect("select large field probe row failed"); + assert_eq!(result.row_count, 1); + let row = result.data.first().expect("large field row should exist"); + let body = row + .get("body") + .and_then(|v| v.as_str()) + .expect("body should be string"); + assert_eq!(body.len(), 70000); + assert!(row.get("payload").is_some(), "payload should exist"); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_error_handling_for_sql_error() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let err = driver + .execute_query("SELECT * FROM __dbpaw_table_not_exists".to_string()) + .await + .expect_err("invalid SQL should return query error"); + assert!( + err.contains("[QUERY_ERROR]"), + "unexpected error shape: {}", + err + ); +} + +#[tokio::test] +#[ignore] +async fn test_mysql_concurrent_connections_can_query() { + 
let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let mut handles = Vec::new(); + + for _ in 0..8 { + let task_form = form.clone(); + handles.push(tokio::spawn(async move { + let driver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&task_form)).await; + driver.execute_query("SELECT 1 AS ok".to_string()).await + })); + } + + for handle in handles { + let result = handle.await.expect("concurrent mysql task panicked"); + let data = result.expect("concurrent mysql query failed"); + assert_eq!(data.row_count, 1); + let ok = data.data[0]["ok"] + .as_str() + .expect("ok should be a stringified scalar"); + assert_eq!(ok, "1"); + } +} + +#[tokio::test] +#[ignore] +async fn test_mysql_view_can_be_listed_and_queried() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MYSQL_DB or container default database should be present"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let base_table = "dbpaw_view_base_probe"; + let view_name = "dbpaw_view_probe_v"; + let qualified_table = format!("`{}`.`{}`", database, base_table); + let qualified_view = format!("`{}`.`{}`", database, view_name); + + let _ = driver + .execute_query(format!("DROP VIEW IF EXISTS {}", qualified_view)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified_table)) + .await; + + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name VARCHAR(30), score INT)", + qualified_table + )) + .await + .expect("create base table for view failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name, score) VALUES (1, 'alice', 10), (2, 'bob', 20)", + qualified_table + )) + .await + 
.expect("insert base rows for view failed"); + driver + .execute_query(format!( + "CREATE VIEW {} AS SELECT id, name FROM {} WHERE score >= 20", + qualified_view, qualified_table + )) + .await + .expect("create view failed"); + + let tables = driver + .list_tables(Some(database.clone())) + .await + .expect("list_tables failed"); + assert!( + tables + .iter() + .any(|t| t.name == base_table && t.r#type == "table"), + "list_tables should include base table" + ); + assert!( + tables + .iter() + .any(|t| t.name == view_name && t.r#type == "view"), + "list_tables should include view with type=view" + ); + + let view_rows = driver + .execute_query(format!( + "SELECT id, name FROM {} ORDER BY id", + qualified_view + )) + .await + .expect("select from view failed"); + assert_eq!(view_rows.row_count, 1); + let row = view_rows.data.first().expect("view row should exist"); + let id_matches = row["id"] == serde_json::Value::Number(2.into()) + || row["id"] == serde_json::Value::String("2".to_string()); + assert!(id_matches, "unexpected id payload: {}", row["id"]); + assert_eq!(row["name"], serde_json::Value::String("bob".to_string())); + + let _ = driver + .execute_query(format!("DROP VIEW IF EXISTS {}", qualified_view)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified_table)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_connection_failure_with_wrong_password() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, mut form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + form.password = Some("dbpaw_wrong_password".to_string()); + + let err = match MysqlDriver::connect(&form).await { + Ok(_) => panic!("wrong password should fail"), + Err(err) => err, + }; + assert!( + err.starts_with("[CONN_FAILED]"), + "unexpected error: {}", + err + ); + assert!(!err.trim().is_empty(), "error message should not be empty"); +} + +#[tokio::test] +#[ignore] +async fn 
test_mysql_connection_timeout_or_unreachable_host_error() { + let form = dbpaw_lib::models::ConnectionForm { + driver: "mysql".to_string(), + host: Some("203.0.113.1".to_string()), + port: Some(3306), + username: Some("root".to_string()), + password: Some("123456".to_string()), + database: Some("test_db".to_string()), + ssl: Some(false), + ..Default::default() + }; + + let err = match MysqlDriver::connect(&form).await { + Ok(_) => panic!("unreachable host should fail"), + Err(err) => err, + }; + assert!( + err.starts_with("[CONN_FAILED]"), + "unexpected error: {}", + err + ); + assert!( + err.contains("could not reach the server") + || err.to_ascii_lowercase().contains("timed out") + || err.to_ascii_lowercase().contains("timeout") + || err.to_ascii_lowercase().contains("network unreachable"), + "unexpected timeout/unreachable error: {}", + err + ); +} + +#[tokio::test] +#[ignore] +async fn test_mysql_batch_insert_and_batch_execute_flow() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MYSQL_DB or container default database should be present"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_batch_probe"; + let qualified = format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, category VARCHAR(20), score INT)", + qualified + )) + .await + .expect("create batch probe table failed"); + + let value_rows: Vec = (1..=50) + .map(|id| { + let category = if id <= 25 { "alpha" } else { "beta" }; + format!("({}, '{}', {})", id, category, id) + }) + .collect(); + let insert_sql = format!( + "INSERT INTO {} (id, category, score) VALUES {}", + qualified, + 
value_rows.join(", ") + ); + let inserted = driver + .execute_query(insert_sql) + .await + .expect("batch insert failed"); + assert_eq!(inserted.row_count, 50); + + let batch_sqls = vec![ + format!( + "UPDATE {} SET score = score + 100 WHERE id <= 10", + qualified + ), + format!( + "UPDATE {} SET category = 'gamma' WHERE id BETWEEN 30 AND 40", + qualified + ), + format!("DELETE FROM {} WHERE id IN (3, 6, 9, 12, 15)", qualified), + ]; + let mut affected = Vec::new(); + for sql in batch_sqls { + let result = driver + .execute_query(sql) + .await + .expect("batch execute statement failed"); + affected.push(result.row_count); + } + assert_eq!(affected, vec![10, 11, 5]); + + let check_total = driver + .execute_query(format!("SELECT COUNT(*) AS c FROM {}", qualified)) + .await + .expect("count after batch execute failed"); + let total = check_total.data[0]["c"] + .as_str() + .expect("count should be string") + .parse::() + .expect("count should be numeric"); + assert_eq!(total, 45); + + let check_gamma = driver + .execute_query(format!( + "SELECT COUNT(*) AS c FROM {} WHERE category = 'gamma'", + qualified + )) + .await + .expect("count gamma rows failed"); + let gamma = check_gamma.data[0]["c"] + .as_str() + .expect("gamma count should be string") + .parse::() + .expect("gamma count should be numeric"); + assert_eq!(gamma, 11); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_prepared_statements_prepare_execute_and_deallocate() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + let database = form + .database + .clone() + .expect("MYSQL_DB or container default database should be present"); + let driver: MysqlDriver = + mysql_context::connect_with_retry(|| MysqlDriver::connect(&form)).await; + + let table_name = "dbpaw_prepared_stmt_probe"; + let qualified = 
format!("`{}`.`{}`", database, table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name VARCHAR(30))", + qualified + )) + .await + .expect("create prepared stmt probe table failed"); + + let mut conn = driver + .pool + .acquire() + .await + .expect("acquire mysql pooled connection failed"); + let prepared_insert_sql = format!("INSERT INTO {} (id, name) VALUES (?, ?)", qualified); + let insert_a = sqlx::query(&prepared_insert_sql) + .bind(1_i64) + .bind("alice") + .execute(&mut *conn) + .await + .expect("prepared insert alice failed"); + assert_eq!(insert_a.rows_affected(), 1); + let insert_b = sqlx::query(&prepared_insert_sql) + .bind(2_i64) + .bind("bob") + .execute(&mut *conn) + .await + .expect("prepared insert bob failed"); + assert_eq!(insert_b.rows_affected(), 1); + + let prepared_update_sql = format!("UPDATE {} SET name = ? WHERE id = ?", qualified); + let updated = sqlx::query(&prepared_update_sql) + .bind("alice-updated") + .bind(1_i64) + .execute(&mut *conn) + .await + .expect("prepared update failed"); + assert_eq!(updated.rows_affected(), 1); + + let prepared_select_sql = format!("SELECT name FROM {} WHERE id = ?", qualified); + let selected_name: String = sqlx::query_scalar(&prepared_select_sql) + .bind(1_i64) + .fetch_one(&mut *conn) + .await + .expect("prepared select failed"); + assert_eq!(selected_name, "alice-updated"); + drop(conn); + + let verify = driver + .execute_query(format!("SELECT COUNT(*) AS c FROM {}", qualified)) + .await + .expect("verify prepared writes failed"); + let total = verify.data[0]["c"] + .as_str() + .expect("verify count should be string") + .parse::() + .expect("verify count should parse"); + assert_eq!(total, 2); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} diff --git a/src-tauri/tests/mysql_stateful_command_integration.rs 
b/src-tauri/tests/mysql_stateful_command_integration.rs new file mode 100644 index 0000000..002d50a --- /dev/null +++ b/src-tauri/tests/mysql_stateful_command_integration.rs @@ -0,0 +1,846 @@ +#[path = "common/mysql_context.rs"] +mod mysql_context; + +use dbpaw_lib::commands::connection::{self, CreateDatabasePayload}; +use dbpaw_lib::commands::{ai, query, storage, transfer}; +use dbpaw_lib::commands::metadata; +use dbpaw_lib::ai::types::AiChatRequest; +use dbpaw_lib::db::drivers::mysql::MysqlDriver; +use dbpaw_lib::db::drivers::DatabaseDriver; +use dbpaw_lib::db::local::LocalDb; +use dbpaw_lib::models::{AiProviderForm, ConnectionForm}; +use dbpaw_lib::state::AppState; +use std::fs; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; +use testcontainers::clients::Cli; +use tokio::time::{sleep, Duration}; + +fn unique_name(prefix: &str) -> String { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after unix epoch") + .as_millis(); + format!("{}_{}", prefix, millis) +} + +async fn wait_until_mysql_ready(form: &ConnectionForm) { + let mut last_error = String::new(); + for _ in 0..45 { + let probe = form.clone(); + match connection::test_connection_ephemeral(probe).await { + Ok(_) => return, + Err(err) => { + last_error = err; + sleep(Duration::from_secs(1)).await; + } + } + } + panic!("mysql is not ready for stateful command tests: {last_error}"); +} + +async fn init_state_with_local_db() -> AppState { + let state = AppState::new(); + let local_db_dir = std::env::temp_dir().join(unique_name("dbpaw_localdb_stateful_it")); + let db = LocalDb::init_with_app_dir(&local_db_dir) + .await + .expect("failed to initialize local db"); + let mut lock = state.local_db.lock().await; + *lock = Some(Arc::new(db)); + drop(lock); + state +} + +async fn create_mysql_connection_for_state( + state: &AppState, + base_form: &ConnectionForm, + suffix: &str, +) -> i64 { + let mut form = base_form.clone(); + form.name = 
Some(format!("mysql-stateful-{suffix}")); + let created = connection::create_connection_direct(state, form) + .await + .expect("create_connection should succeed"); + created.id +} + +async fn drop_database_if_exists(form: &ConnectionForm, db_name: &str) { + let driver = MysqlDriver::connect(form) + .await + .expect("failed to connect mysql driver for cleanup"); + let _ = driver + .execute_query(format!("DROP DATABASE IF EXISTS `{}`", db_name)) + .await; + driver.close().await; +} + +async fn prepare_metadata_fixture( + form: &ConnectionForm, + schema: &str, + parent_table: &str, + child_table: &str, +) { + let driver = MysqlDriver::connect(form) + .await + .expect("failed to connect mysql driver for metadata fixture"); + let parent_qualified = format!("`{}`.`{}`", schema, parent_table); + let child_qualified = format!("`{}`.`{}`", schema, child_table); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", child_qualified)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", parent_qualified)) + .await; + + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, code VARCHAR(30))", + parent_qualified + )) + .await + .expect("create metadata parent table should succeed"); + driver + .execute_query(format!( + "CREATE TABLE {} (\ + id INT PRIMARY KEY, \ + parent_id INT NOT NULL, \ + name VARCHAR(64), \ + INDEX idx_child_name (name), \ + CONSTRAINT fk_child_parent FOREIGN KEY (parent_id) REFERENCES {}(id)\ + )", + child_qualified, parent_qualified + )) + .await + .expect("create metadata child table should succeed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, code) VALUES (1, 'p1')", + parent_qualified + )) + .await + .expect("insert parent row should succeed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, parent_id, name) VALUES (10, 1, 'child-a')", + child_qualified + )) + .await + .expect("insert child row should succeed"); + driver.close().await; +} + +async fn 
cleanup_metadata_fixture(form: &ConnectionForm, schema: &str, parent_table: &str, child_table: &str) { + let driver = MysqlDriver::connect(form) + .await + .expect("failed to connect mysql driver for metadata cleanup"); + let parent_qualified = format!("`{}`.`{}`", schema, parent_table); + let child_qualified = format!("`{}`.`{}`", schema, child_table); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", child_qualified)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", parent_qualified)) + .await; + driver.close().await; +} + +async fn get_local_db(state: &AppState) -> Arc { + let lock = state.local_db.lock().await; + lock.as_ref() + .cloned() + .expect("local db should be initialized") +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_create_database_by_id_success() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "create-db-success").await; + + let db_name = unique_name("dbpaw_cmd_created_db"); + let payload = CreateDatabasePayload { + name: db_name.clone(), + if_not_exists: Some(true), + charset: None, + collation: None, + encoding: None, + lc_collate: None, + lc_ctype: None, + }; + + connection::create_database_by_id_direct(&state, conn_id, payload) + .await + .expect("create_database_by_id should succeed"); + let dbs = connection::list_databases_by_id_direct(&state, conn_id) + .await + .expect("list_databases_by_id should succeed"); + assert!(dbs.iter().any(|d| d == &db_name)); + + drop_database_if_exists(&form, &db_name).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_create_database_by_id_if_not_exists_idempotent() { + let docker = 
(!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "create-db-idempotent").await; + + let db_name = unique_name("dbpaw_cmd_idempotent_db"); + let payload = CreateDatabasePayload { + name: db_name.clone(), + if_not_exists: Some(true), + charset: None, + collation: None, + encoding: None, + lc_collate: None, + lc_ctype: None, + }; + + connection::create_database_by_id_direct(&state, conn_id, payload.clone()) + .await + .expect("first create_database_by_id should succeed"); + connection::create_database_by_id_direct(&state, conn_id, payload) + .await + .expect("second create_database_by_id should succeed"); + + drop_database_if_exists(&form, &db_name).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_create_database_by_id_invalid_name_returns_validation_error() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "invalid-db-name").await; + + let payload = CreateDatabasePayload { + name: " ".to_string(), + if_not_exists: Some(true), + charset: None, + collation: None, + encoding: None, + lc_collate: None, + lc_ctype: None, + }; + let result = connection::create_database_by_id_direct(&state, conn_id, payload).await; + assert!(result.is_err()); + let err = result.err().unwrap_or_default(); + assert!(err.contains("[VALIDATION_ERROR]")); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn 
test_mysql_command_list_databases_by_id_success() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "list-db-success").await; + + let target_db = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + let dbs = connection::list_databases_by_id_direct(&state, conn_id) + .await + .expect("list_databases_by_id should succeed"); + assert!(!dbs.is_empty()); + assert!(dbs.iter().any(|d| d == &target_db)); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_list_databases_by_id_invalid_id_returns_error() { + let state = init_state_with_local_db().await; + let result = connection::list_databases_by_id_direct(&state, -999_999).await; + assert!(result.is_err()); + let err = result.err().unwrap_or_default(); + assert!(!err.trim().is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_connection_crud_flow_create_get_update_delete() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + + let unique = unique_name("dbpaw_cmd_conn"); + let mut create_form = form.clone(); + create_form.name = Some(format!("mysql-{unique}-created")); + let created = connection::create_connection_direct(&state, create_form) + .await + .expect("create_connection should succeed"); + let conn_id = created.id; + + let listed = connection::get_connections_direct(&state) + .await + .expect("get_connections after create should succeed"); + assert!(listed.iter().any(|c| c.id == conn_id)); + + let mut update_form = 
form.clone(); + update_form.name = Some(format!("mysql-{unique}-updated")); + update_form.database = form.database.clone().or(Some("test_db".to_string())); + let updated = connection::update_connection_direct(&state, conn_id, update_form) + .await + .expect("update_connection should succeed"); + assert_eq!(updated.id, conn_id); + assert_eq!(updated.name, format!("mysql-{unique}-updated")); + + connection::delete_connection_direct(&state, conn_id) + .await + .expect("delete_connection should succeed"); + let listed_after_delete = connection::get_connections_direct(&state) + .await + .expect("get_connections after delete should succeed"); + assert!(!listed_after_delete.iter().any(|c| c.id == conn_id)); +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_get_table_structure_success() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "meta-structure-success").await; + let schema = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + let parent = unique_name("dbpaw_meta_parent"); + let child = unique_name("dbpaw_meta_child"); + prepare_metadata_fixture(&form, &schema, &parent, &child).await; + + let structure = metadata::get_table_structure_direct(&state, conn_id, schema.clone(), child.clone()) + .await + .expect("get_table_structure should succeed"); + assert!(structure.columns.iter().any(|c| c.name == "id")); + assert!(structure.columns.iter().any(|c| c.name == "parent_id")); + + cleanup_metadata_fixture(&form, &schema, &parent, &child).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_get_table_structure_missing_table_returns_error() { + let docker = 
(!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "meta-structure-missing").await; + let schema = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + let missing_table = unique_name("dbpaw_meta_missing"); + + let result = metadata::get_table_structure_direct(&state, conn_id, schema, missing_table).await; + assert!(result.is_err()); + let err = result.err().unwrap_or_default(); + assert!(!err.trim().is_empty()); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_get_table_ddl_success() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "meta-ddl-success").await; + let schema = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + let parent = unique_name("dbpaw_meta_parent"); + let child = unique_name("dbpaw_meta_child"); + prepare_metadata_fixture(&form, &schema, &parent, &child).await; + + let ddl = metadata::get_table_ddl_direct( + &state, + conn_id, + Some(schema.clone()), + schema.clone(), + child.clone(), + ) + .await + .expect("get_table_ddl should succeed"); + assert!(ddl.to_uppercase().contains("CREATE TABLE")); + assert!(ddl.contains(&child)); + + cleanup_metadata_fixture(&form, &schema, &parent, &child).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_get_table_metadata_contains_indexes_and_foreign_keys() { + let 
docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "meta-metadata-success").await; + let schema = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + let parent = unique_name("dbpaw_meta_parent"); + let child = unique_name("dbpaw_meta_child"); + prepare_metadata_fixture(&form, &schema, &parent, &child).await; + + let meta = metadata::get_table_metadata_direct( + &state, + conn_id, + Some(schema.clone()), + schema.clone(), + child.clone(), + ) + .await + .expect("get_table_metadata should succeed"); + assert!(meta.indexes.iter().any(|idx| idx.name == "idx_child_name")); + assert!(meta.foreign_keys.iter().any(|fk| fk.column == "parent_id")); + + cleanup_metadata_fixture(&form, &schema, &parent, &child).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_get_schema_overview_contains_target_schema() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "meta-schema-overview").await; + let schema = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + let parent = unique_name("dbpaw_meta_parent"); + let child = unique_name("dbpaw_meta_child"); + prepare_metadata_fixture(&form, &schema, &parent, &child).await; + + let overview = metadata::get_schema_overview_direct( + &state, + conn_id, + Some(schema.clone()), + Some(schema.clone()), + ) + .await + .expect("get_schema_overview should succeed"); + 
assert!(overview.tables.iter().any(|t| t.schema == schema && t.name == child)); + + cleanup_metadata_fixture(&form, &schema, &parent, &child).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_execute_query_by_id_success() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "query-by-id-success").await; + let schema = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + + let result = query::execute_query_by_id_direct( + &state, + conn_id, + "SELECT 1 AS v".to_string(), + Some(schema), + Some("phase4_success".to_string()), + Some("phase4-qid-success".to_string()), + ) + .await + .expect("execute_query_by_id should succeed"); + assert!(result.success); + assert!(result.row_count >= 1); + assert!(!result.data.is_empty()); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_execute_query_by_id_invalid_sql_returns_error() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "query-by-id-invalid").await; + let schema = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + + let result = query::execute_query_by_id_direct( + &state, + conn_id, + "SELECT * FROM __dbpaw_missing_phase4_table".to_string(), + Some(schema), + Some("phase4_invalid".to_string()), + Some("phase4-qid-invalid".to_string()), + ) + .await; + 
assert!(result.is_err()); + let err = result.err().unwrap_or_default(); + assert!(!err.trim().is_empty()); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_list_sql_execution_logs_contains_recent_entries() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "query-log-list").await; + let schema = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + + query::execute_query_by_id_direct( + &state, + conn_id, + "SELECT 1 AS phase4_log_probe".to_string(), + Some(schema), + Some("phase4_log_probe".to_string()), + Some("phase4-qid-log".to_string()), + ) + .await + .expect("execute_query_by_id for log probe should succeed"); + + let logs = query::list_sql_execution_logs_direct(&state, Some(20)) + .await + .expect("list_sql_execution_logs should succeed"); + assert!(!logs.is_empty()); + assert!(logs.iter().any(|l| { + l.source.as_deref() == Some("phase4_log_probe") + && l.sql.contains("phase4_log_probe") + && l.success + })); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_cancel_query_non_clickhouse_returns_false() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "query-cancel-non-ch").await; + + let canceled = query::cancel_query_direct( + &state, + conn_id.to_string(), + "phase4-qid-cancel".to_string(), + ) + .await + .expect("cancel_query 
should return bool"); + assert!(!canceled); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_storage_saved_query_crud_flow() { + let state = init_state_with_local_db().await; + let name = unique_name("saved_query"); + let created = storage::save_query_direct( + &state, + name.clone(), + "SELECT 1".to_string(), + Some("desc".to_string()), + None, + Some("test_db".to_string()), + ) + .await + .expect("save_query should succeed"); + assert_eq!(created.name, name); + + let all = storage::get_saved_queries_direct(&state) + .await + .expect("get_saved_queries should succeed"); + assert!(all.iter().any(|q| q.id == created.id)); + + let updated = storage::update_saved_query_direct( + &state, + created.id, + format!("{}_updated", created.name), + "SELECT 2".to_string(), + Some("desc2".to_string()), + None, + Some("test_db".to_string()), + ) + .await + .expect("update_saved_query should succeed"); + assert_eq!(updated.query, "SELECT 2"); + + storage::delete_saved_query_direct(&state, created.id) + .await + .expect("delete_saved_query should succeed"); + let all_after = storage::get_saved_queries_direct(&state) + .await + .expect("get_saved_queries after delete should succeed"); + assert!(!all_after.iter().any(|q| q.id == created.id)); +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_transfer_export_and_import_minimal_flow() { + let docker = (!mysql_context::should_reuse_local_db()).then(Cli::default); + let (_mysql_container, form) = mysql_context::mysql_form_from_test_context(docker.as_ref()); + wait_until_mysql_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_mysql_connection_for_state(&state, &form, "transfer-minimal").await; + let schema = form + .database + .clone() + .unwrap_or_else(|| "test_db".to_string()); + let table = unique_name("dbpaw_transfer_src"); + let qualified = format!("`{}`.`{}`", schema, table); + let driver = 
MysqlDriver::connect(&form) + .await + .expect("failed to connect mysql driver"); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name VARCHAR(64))", + qualified + )) + .await + .expect("create transfer src table should succeed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name) VALUES (1, 'a'), (2, 'b')", + qualified + )) + .await + .expect("insert transfer src rows should succeed"); + driver.close().await; + + let base = std::env::temp_dir().join(unique_name("dbpaw_transfer_it")); + fs::create_dir_all(&base).expect("create temp transfer dir should succeed"); + let table_export_path = base.join("table_export.csv"); + let query_export_path = base.join("query_export.json"); + let import_sql_path = base.join("import.sql"); + + let table_export = transfer::export_table_data_direct( + &state, + conn_id, + Some(schema.clone()), + schema.clone(), + table.clone(), + "mysql".to_string(), + transfer::ExportFormat::Csv, + transfer::ExportScope::FullTable, + None, + None, + None, + None, + None, + None, + Some(table_export_path.to_string_lossy().to_string()), + Some(100), + ) + .await + .expect("export_table_data should succeed"); + assert!(table_export.row_count >= 2); + assert!(std::path::Path::new(&table_export.file_path).exists()); + + let query_export = transfer::export_query_result_direct( + &state, + conn_id, + Some(schema.clone()), + format!("SELECT * FROM {} ORDER BY id", qualified), + "mysql".to_string(), + transfer::ExportFormat::Json, + Some(query_export_path.to_string_lossy().to_string()), + ) + .await + .expect("export_query_result should succeed"); + assert!(query_export.row_count >= 2); + assert!(std::path::Path::new(&query_export.file_path).exists()); + + let import_table = unique_name("dbpaw_import_dst"); + let import_sql = format!( + "CREATE TABLE `{}`.`{}` (id INT PRIMARY KEY, name VARCHAR(64));\nINSERT INTO `{}`.`{}` (id, 
name) VALUES (1, 'x');", + schema, import_table, schema, import_table + ); + fs::write(&import_sql_path, import_sql).expect("write import sql file should succeed"); + let import_result = transfer::import_sql_file_direct( + &state, + conn_id, + Some(schema.clone()), + import_sql_path.to_string_lossy().to_string(), + "mysql".to_string(), + ) + .await + .expect("import_sql_file should succeed"); + assert_eq!(import_result.success_statements, import_result.total_statements); + assert!(import_result.error.is_none()); + + let cleanup_driver = MysqlDriver::connect(&form) + .await + .expect("failed to connect mysql driver for transfer cleanup"); + let _ = cleanup_driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + let _ = cleanup_driver + .execute_query(format!("DROP TABLE IF EXISTS `{}`.`{}`", schema, import_table)) + .await; + cleanup_driver.close().await; + let _ = fs::remove_file(table_export_path); + let _ = fs::remove_file(query_export_path); + let _ = fs::remove_file(import_sql_path); + let _ = fs::remove_dir_all(base); + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_mysql_command_ai_minimal_provider_conversation_and_chat_flow() { + let state = init_state_with_local_db().await; + + let start_without_provider = ai::ai_chat_start_direct( + &state, + AiChatRequest { + request_id: unique_name("ai_start_no_provider"), + provider_id: None, + conversation_id: None, + scenario: "sql_generate".to_string(), + input: "select 1".to_string(), + title: Some("phase5 start".to_string()), + connection_id: None, + database: None, + schema_overview: None, + selected_tables: None, + }, + ) + .await; + assert!(start_without_provider.is_err()); + + let created_provider = ai::ai_create_provider_direct( + &state, + AiProviderForm { + name: unique_name("ai_provider"), + provider_type: Some("openai".to_string()), + base_url: "https://example.invalid/v1".to_string(), + model: 
"gpt-4o-mini".to_string(), + api_key: Some("sk-test".to_string()), + is_default: Some(true), + enabled: Some(true), + extra_json: None, + }, + ) + .await + .expect("ai_create_provider should succeed"); + let providers = ai::ai_list_providers_direct(&state) + .await + .expect("ai_list_providers should succeed"); + assert!(providers.iter().any(|p| p.id == created_provider.id)); + + let updated_provider = ai::ai_update_provider_direct( + &state, + created_provider.id, + AiProviderForm { + name: format!("{}_updated", created_provider.name), + provider_type: Some("openai".to_string()), + base_url: "https://example.invalid/v1".to_string(), + model: "gpt-4o-mini".to_string(), + api_key: Some("sk-test-2".to_string()), + is_default: Some(true), + enabled: Some(true), + extra_json: None, + }, + ) + .await + .expect("ai_update_provider should succeed"); + assert_eq!(updated_provider.id, created_provider.id); + + ai::ai_set_default_provider_direct(&state, created_provider.id) + .await + .expect("ai_set_default_provider should succeed"); + ai::ai_clear_provider_api_key_direct(&state, "openai".to_string()) + .await + .expect("ai_clear_provider_api_key should succeed"); + + let continue_without_conversation = ai::ai_chat_continue_direct( + &state, + AiChatRequest { + request_id: unique_name("ai_continue_no_conv"), + provider_id: Some(created_provider.id), + conversation_id: None, + scenario: "sql_generate".to_string(), + input: "continue".to_string(), + title: None, + connection_id: None, + database: None, + schema_overview: None, + selected_tables: None, + }, + ) + .await; + assert!(continue_without_conversation.is_err()); + + let db = get_local_db(&state).await; + let conv = db + .create_ai_conversation( + unique_name("ai_conv"), + "sql_generate".to_string(), + None, + None, + ) + .await + .expect("create ai conversation in local db should succeed"); + let conversations = ai::ai_list_conversations_direct(&state, None, None) + .await + .expect("ai_list_conversations should 
succeed"); + assert!(conversations.iter().any(|c| c.id == conv.id)); + let detail = ai::ai_get_conversation_direct(&state, conv.id) + .await + .expect("ai_get_conversation should succeed"); + assert_eq!(detail.conversation.id, conv.id); + ai::ai_delete_conversation_direct(&state, conv.id) + .await + .expect("ai_delete_conversation should succeed"); + let conversations_after = ai::ai_list_conversations_direct(&state, None, None) + .await + .expect("ai_list_conversations after delete should succeed"); + assert!(!conversations_after.iter().any(|c| c.id == conv.id)); + + ai::ai_delete_provider_direct(&state, created_provider.id) + .await + .expect("ai_delete_provider should succeed"); +} diff --git a/src-tauri/tests/postgres_command_integration.rs b/src-tauri/tests/postgres_command_integration.rs new file mode 100644 index 0000000..9316f26 --- /dev/null +++ b/src-tauri/tests/postgres_command_integration.rs @@ -0,0 +1,337 @@ +#[path = "common/postgres_context.rs"] +mod postgres_context; + +use dbpaw_lib::commands::{connection, metadata, query}; +use dbpaw_lib::db::drivers::postgres::PostgresDriver; +use dbpaw_lib::db::drivers::DatabaseDriver; +use dbpaw_lib::models::ConnectionForm; +use std::time::{SystemTime, UNIX_EPOCH}; +use testcontainers::clients::Cli; + +fn unique_table_name(prefix: &str) -> String { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after unix epoch") + .as_millis(); + format!("{}_{}", prefix, millis) +} + +async fn wait_until_postgres_ready(form: &ConnectionForm) { + let probe = form.clone(); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&probe)).await; + driver + .test_connection() + .await + .expect("postgres should accept connections for command tests"); + driver.close().await; +} + +async fn prepare_query_test_table(form: &ConnectionForm, schema: &str, table: &str) { + let driver = PostgresDriver::connect(form) + .await + .expect("failed to connect postgres driver"); + 
let qualified = format!("\"{}\".\"{}\"", schema, table); + + driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await + .expect("drop table should succeed"); + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name TEXT)", + qualified + )) + .await + .expect("create table should succeed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name) VALUES (1, 'DbPaw')", + qualified + )) + .await + .expect("insert row should succeed"); + driver.close().await; +} + +async fn cleanup_table(form: &ConnectionForm, schema: &str, table: &str) { + let driver = PostgresDriver::connect(form) + .await + .expect("failed to connect postgres driver for cleanup"); + let qualified = format!("\"{}\".\"{}\"", schema, table); + driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await + .expect("drop table should succeed"); + driver.close().await; +} + +async fn execute_by_conn_sql( + form: ConnectionForm, + sql: String, +) -> Result { + query::execute_by_conn_direct(form, sql).await +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_test_connection_success() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + + let result = connection::test_connection_ephemeral(form) + .await + .expect("test_connection_ephemeral should succeed"); + + assert!(result.success); + assert!(result.latency_ms.is_some()); +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_test_connection_invalid_password_returns_error() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, mut form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let ready_form = form.clone(); + wait_until_postgres_ready(&ready_form).await; + form.password = Some("dbpaw_wrong_password".to_string()); + + let 
result = connection::test_connection_ephemeral(form).await; + + assert!(result.is_err()); + let error = result.err().unwrap_or_default(); + assert!(!error.trim().is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_list_tables_by_conn_contains_created_table() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, mut form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + form.schema = Some("public".to_string()); + let table = unique_table_name("dbpaw_cmd_tables"); + prepare_query_test_table(&form, "public", &table).await; + + let tables = metadata::list_tables_by_conn(form.clone()) + .await + .expect("list_tables_by_conn should succeed"); + + assert!(tables + .iter() + .any(|t| t.schema == "public" && t.name == table)); + cleanup_table(&form, "public", &table).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_list_tables_by_conn_invalid_credentials_returns_error() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, mut form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let ready_form = form.clone(); + wait_until_postgres_ready(&ready_form).await; + form.password = Some("dbpaw_wrong_password".to_string()); + + let result = metadata::list_tables_by_conn(form).await; + + assert!(result.is_err()); + let error = result.err().unwrap_or_default(); + assert!(!error.trim().is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_list_databases_contains_target_db() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let target_db = form + .database + .clone() + .unwrap_or_else(|| "postgres".to_string()); + + let databases = connection::list_databases(form) + 
.await + .expect("list_databases should succeed"); + + assert!(!databases.is_empty()); + assert!(databases.iter().any(|db| db == &target_db)); + assert!(databases.iter().all(|db| !db.trim().is_empty())); +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_list_databases_invalid_credentials_returns_error() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, mut form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let ready_form = form.clone(); + wait_until_postgres_ready(&ready_form).await; + form.password = Some("dbpaw_wrong_password".to_string()); + + let result = connection::list_databases(form).await; + + assert!(result.is_err()); + let error = result.err().unwrap_or_default(); + assert!(!error.trim().is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_execute_by_conn_select_returns_rows() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let table = unique_table_name("dbpaw_cmd_exec_select"); + let qualified = format!("\"public\".\"{}\"", table); + prepare_query_test_table(&form, "public", &table).await; + + let sql = format!("SELECT id, name FROM {} ORDER BY id", qualified); + let result = execute_by_conn_sql(form.clone(), sql) + .await + .expect("execute_by_conn should succeed"); + + assert!(result.success); + assert!(result.row_count >= 1); + assert!(!result.data.is_empty()); + let row = result.data.first().expect("result row should exist"); + let name = row.get("name").and_then(|v| v.as_str()); + assert_eq!(name, Some("DbPaw")); + cleanup_table(&form, "public", &table).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_execute_by_conn_invalid_sql_returns_error() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, 
form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + + let result = execute_by_conn_sql( + form, + "SELECT * FROM __dbpaw_missing_command_table".to_string(), + ) + .await; + + assert!(result.is_err()); + let error = result.err().unwrap_or_default(); + assert!(!error.trim().is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_execute_by_conn_insert_affects_rows() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let table = unique_table_name("dbpaw_cmd_exec_insert"); + let qualified = format!("\"public\".\"{}\"", table); + + let driver = PostgresDriver::connect(&form) + .await + .expect("failed to connect postgres driver"); + driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await + .expect("drop table should succeed"); + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name TEXT)", + qualified + )) + .await + .expect("create table should succeed"); + driver.close().await; + + let sql = format!("INSERT INTO {} (id, name) VALUES (1, 'alpha')", qualified); + let result = execute_by_conn_sql(form.clone(), sql) + .await + .expect("execute_by_conn insert should succeed"); + assert!(result.success); + let verify = execute_by_conn_sql( + form.clone(), + format!("SELECT COUNT(*)::INT AS c FROM {}", qualified), + ) + .await + .expect("verify inserted row should succeed"); + let count = verify + .data + .first() + .and_then(|row| row.get("c")) + .and_then(|v| { + v.as_i64() + .or_else(|| v.as_str().and_then(|s| s.parse::().ok())) + }) + .unwrap_or_default(); + assert_eq!(count, 1); + + cleanup_table(&form, "public", &table).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_get_table_data_by_conn_pagination_works() { + let docker = 
(!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + + let schema = "public".to_string(); + let table = unique_table_name("dbpaw_cmd_page"); + let qualified = format!("\"{}\".\"{}\"", schema, table); + + let driver = PostgresDriver::connect(&form) + .await + .expect("failed to connect postgres driver"); + driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await + .expect("drop table should succeed"); + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name TEXT)", + qualified + )) + .await + .expect("create table should succeed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name) VALUES (1, 'a'), (2, 'b'), (3, 'c')", + qualified + )) + .await + .expect("insert rows should succeed"); + driver.close().await; + + let page1 = query::get_table_data_by_conn(form.clone(), schema.clone(), table.clone(), 1, 2) + .await + .expect("page 1 should succeed"); + let page2 = query::get_table_data_by_conn(form.clone(), schema, table.clone(), 2, 2) + .await + .expect("page 2 should succeed"); + + assert_eq!(page1.total, 3); + assert_eq!(page1.limit, 2); + assert_eq!(page1.page, 1); + assert_eq!(page1.data.len(), 2); + assert_eq!(page2.data.len(), 1); + + cleanup_table(&form, "public", &table).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_get_table_data_by_conn_invalid_pagination_returns_error() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + + let schema = "public".to_string(); + let table = unique_table_name("dbpaw_cmd_invalid_page"); + prepare_query_test_table(&form, "public", &table).await; + + let result = query::get_table_data_by_conn(form.clone(), schema, table.clone(), 0, 
10).await; + assert!(result.is_err()); + let error = result.err().unwrap_or_default(); + assert!(error.contains("[VALIDATION_ERROR]")); + + cleanup_table(&form, "public", &table).await; +} diff --git a/src-tauri/tests/postgres_integration.rs b/src-tauri/tests/postgres_integration.rs index b1daf54..bb9164d 100644 --- a/src-tauri/tests/postgres_integration.rs +++ b/src-tauri/tests/postgres_integration.rs @@ -1,42 +1,17 @@ +#[path = "common/postgres_context.rs"] +mod postgres_context; + use dbpaw_lib::db::drivers::postgres::PostgresDriver; use dbpaw_lib::db::drivers::DatabaseDriver; -use dbpaw_lib::models::ConnectionForm; -use std::env; +use testcontainers::clients::Cli; #[tokio::test] #[ignore] async fn test_postgres_integration_flow() { - let host = env::var("POSTGRES_HOST") - .or_else(|_| env::var("PG_HOST")) - .unwrap_or_else(|_| "localhost".to_string()); - let port = env::var("POSTGRES_PORT") - .or_else(|_| env::var("PG_PORT")) - .unwrap_or_else(|_| "5432".to_string()) - .parse() - .expect("POSTGRES_PORT should be a number"); - let username = env::var("POSTGRES_USER") - .or_else(|_| env::var("PGUSER")) - .unwrap_or_else(|_| "postgres".to_string()); - let password = env::var("POSTGRES_PASSWORD") - .or_else(|_| env::var("PGPASSWORD")) - .unwrap_or_else(|_| "postgres".to_string()); - let database = env::var("POSTGRES_DB") - .or_else(|_| env::var("PGDATABASE")) - .unwrap_or_else(|_| "postgres".to_string()); - - let form = ConnectionForm { - driver: "postgres".to_string(), - host: Some(host), - port: Some(port), - username: Some(username), - password: Some(password), - database: Some(database.clone()), - ..Default::default() - }; + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); - let driver = PostgresDriver::connect(&form) - .await - .expect("Failed to connect to Postgres"); + let driver = postgres_context::connect_with_retry(|| 
PostgresDriver::connect(&form)).await; driver .test_connection() @@ -171,3 +146,841 @@ async fn test_postgres_integration_flow() { .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) .await; } + +#[tokio::test] +#[ignore] +async fn test_postgres_get_table_data_supports_pagination_sort_filter_and_order_by() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&form)).await; + + let table_name = "dbpaw_pg_grid_probe"; + let qualified = format!("public.{}", table_name); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name TEXT, score INT)", + qualified + )) + .await + .expect("create pg grid probe table failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name, score) VALUES \ + (1, 'alpha', 10), (2, 'beta', 20), (3, 'gamma', 30), (4, 'delta', 40)", + qualified + )) + .await + .expect("insert pg grid rows failed"); + + let page1 = driver + .get_table_data( + "public".to_string(), + table_name.to_string(), + 1, + 2, + Some("score".to_string()), + Some("desc".to_string()), + None, + None, + ) + .await + .expect("get_table_data for page1 failed"); + assert_eq!(page1.total, 4); + assert_eq!(page1.data.len(), 2); + assert_eq!( + page1.data[0]["name"], + serde_json::Value::String("delta".to_string()) + ); + + let filtered = driver + .get_table_data( + "public".to_string(), + table_name.to_string(), + 1, + 10, + None, + None, + Some("score >= 20".to_string()), + None, + ) + .await + .expect("get_table_data with filter failed"); + assert_eq!(filtered.total, 3); + + let ordered = driver + .get_table_data( + "public".to_string(), + table_name.to_string(), + 1, + 1, + Some("id".to_string()), + Some("asc".to_string()), + None, + 
Some("name DESC".to_string()), + ) + .await + .expect("get_table_data with order_by priority failed"); + assert_eq!(ordered.total, 4); + assert_eq!(ordered.data.len(), 1); + assert_eq!( + ordered.data[0]["name"], + serde_json::Value::String("gamma".to_string()) + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_get_table_data_rejects_invalid_sort_column() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&form)).await; + + let table_name = "dbpaw_pg_invalid_sort_probe"; + let qualified = format!("public.{}", table_name); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!("CREATE TABLE {} (id INT PRIMARY KEY)", qualified)) + .await + .expect("create invalid sort probe table failed"); + + let result = driver + .get_table_data( + "public".to_string(), + table_name.to_string(), + 1, + 10, + Some("id desc".to_string()), + Some("desc".to_string()), + None, + None, + ) + .await; + let err = result.expect_err("invalid sort column should return an error"); + assert!( + err.contains("[VALIDATION_ERROR] Invalid sort column name"), + "unexpected error: {}", + err + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_table_structure_and_schema_overview() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&form)).await; + + let table_name = "dbpaw_pg_overview_probe"; + let qualified = 
format!("public.{}", table_name); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, label TEXT NOT NULL)", + qualified + )) + .await + .expect("create overview probe table failed"); + + let structure = driver + .get_table_structure("public".to_string(), table_name.to_string()) + .await + .expect("get_table_structure failed"); + assert!( + structure.columns.iter().any(|c| c.name == "id"), + "table structure should include id column" + ); + assert!( + structure.columns.iter().any(|c| c.name == "label"), + "table structure should include label column" + ); + + let overview = driver + .get_schema_overview(Some("public".to_string())) + .await + .expect("get_schema_overview failed"); + assert!( + overview + .tables + .iter() + .any(|t| t.schema == "public" && t.name == table_name), + "schema overview should include public.{}", + table_name + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_metadata_includes_indexes_and_foreign_keys() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&form)).await; + + let parent = "dbpaw_pg_parent_meta_probe"; + let child = "dbpaw_pg_child_meta_probe"; + let parent_qualified = format!("public.{}", parent); + let child_qualified = format!("public.{}", child); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", child_qualified)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", parent_qualified)) + .await; + + driver + .execute_query(format!("CREATE TABLE {} (id INT PRIMARY KEY)", parent_qualified)) + .await + .expect("create parent table failed"); + driver + 
.execute_query(format!( + "CREATE TABLE {} (\ + id INT PRIMARY KEY, \ + parent_id INT NOT NULL, \ + name TEXT, \ + CONSTRAINT fk_pg_child_parent FOREIGN KEY (parent_id) REFERENCES {}(id)\ + )", + child_qualified, parent_qualified + )) + .await + .expect("create child table with fk failed"); + driver + .execute_query(format!( + "CREATE INDEX idx_pg_child_name ON {} (name)", + child_qualified + )) + .await + .expect("create child index failed"); + + let metadata = driver + .get_table_metadata("public".to_string(), child.to_string()) + .await + .expect("get_table_metadata for child failed"); + assert!( + metadata + .indexes + .iter() + .any(|i| i.name == "idx_pg_child_name" && i.columns.contains(&"name".to_string())), + "metadata should include idx_pg_child_name" + ); + assert!( + metadata.foreign_keys.iter().any(|fk| { + fk.name == "fk_pg_child_parent" + && fk.column == "parent_id" + && fk.referenced_table == parent + && fk.referenced_column == "id" + }), + "metadata should include fk_pg_child_parent" + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", child_qualified)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", parent_qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_boolean_and_json_type_mapping_regression() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&form)).await; + + let table_name = "dbpaw_pg_bool_json_probe"; + let qualified = format!("public.{}", table_name); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, flag BOOLEAN, meta JSONB)", + qualified + )) + .await + .expect("create bool/json probe table failed"); + driver + .execute_query(format!( + 
"INSERT INTO {} (id, flag, meta) VALUES (1, true, '{{\"tier\":\"gold\"}}'::jsonb)", + qualified + )) + .await + .expect("insert bool/json probe row failed"); + + let query_result = driver + .execute_query(format!("SELECT flag, meta FROM {} WHERE id = 1", qualified)) + .await + .expect("select bool/json row failed"); + assert_eq!(query_result.row_count, 1); + let query_row = query_result.data.first().expect("query row should exist"); + assert_eq!(query_row["flag"], serde_json::Value::Bool(true)); + assert!( + query_row.get("meta").is_some(), + "meta should exist in query result" + ); + + let table_data = driver + .get_table_data( + "public".to_string(), + table_name.to_string(), + 1, + 10, + None, + None, + None, + None, + ) + .await + .expect("get_table_data for bool/json table failed"); + assert_eq!(table_data.total, 1); + let grid_row = table_data.data.first().expect("table data row should exist"); + assert_eq!(grid_row["flag"], serde_json::Value::Bool(true)); + assert!( + grid_row.get("meta").is_some(), + "meta should exist in table_data" + ); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_transaction_commit_and_rollback() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&form)).await; + + let table_name = "dbpaw_pg_txn_probe"; + let qualified = format!("public.{}", table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name TEXT)", + qualified + )) + .await + .expect("create pg txn probe table failed"); + + let mut rollback_tx = driver.pool.begin().await.expect("begin rollback tx failed"); + sqlx::query(&format!( + "INSERT INTO {} (id, 
name) VALUES (1, 'rolled_back')", + qualified + )) + .execute(&mut *rollback_tx) + .await + .expect("insert in rollback tx failed"); + rollback_tx.rollback().await.expect("rollback tx failed"); + + let rolled_back = driver + .execute_query(format!( + "SELECT COUNT(*) AS c FROM {} WHERE id = 1", + qualified + )) + .await + .expect("count after rollback failed"); + let rolled_back_count = rolled_back.data[0]["c"] + .as_str() + .expect("rollback count should be string") + .parse::<i64>() + .expect("rollback count should be numeric"); + assert_eq!(rolled_back_count, 0); + + let mut commit_tx = driver.pool.begin().await.expect("begin commit tx failed"); + sqlx::query(&format!( + "INSERT INTO {} (id, name) VALUES (2, 'committed')", + qualified + )) + .execute(&mut *commit_tx) + .await + .expect("insert in commit tx failed"); + commit_tx.commit().await.expect("commit tx failed"); + + let committed = driver + .execute_query(format!( + "SELECT COUNT(*) AS c FROM {} WHERE id = 2", + qualified + )) + .await + .expect("count after commit failed"); + let committed_count = committed.data[0]["c"] + .as_str() + .expect("commit count should be string") + .parse::<i64>() + .expect("commit count should be numeric"); + assert_eq!(committed_count, 1); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_execute_query_reports_affected_rows_for_update_delete() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&form)).await; + + let table_name = "dbpaw_pg_affected_rows_probe"; + let qualified = format!("public.{}", table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name 
TEXT)", + qualified + )) + .await + .expect("create affected_rows probe table failed"); + + let inserted = driver + .execute_query(format!( + "INSERT INTO {} (id, name) VALUES (1, 'a'), (2, 'b')", + qualified + )) + .await + .expect("insert affected_rows probe rows failed"); + assert!(inserted.success); + + let updated = driver + .execute_query(format!("UPDATE {} SET name = 'bb' WHERE id = 2", qualified)) + .await + .expect("update affected_rows probe row failed"); + assert!(updated.success); + + let deleted = driver + .execute_query(format!("DELETE FROM {} WHERE id IN (1, 2)", qualified)) + .await + .expect("delete affected_rows probe rows failed"); + assert!(deleted.success); + + let remain = driver + .execute_query(format!("SELECT COUNT(*) AS c FROM {}", qualified)) + .await + .expect("count after delete should succeed"); + let remain_count = remain.data[0]["c"] + .as_str() + .expect("count should be string") + .parse::<i64>() + .expect("count should be numeric"); + assert_eq!(remain_count, 0); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_large_text_and_blob_round_trip() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&form)).await; + + let table_name = "dbpaw_pg_large_field_probe"; + let qualified = format!("public.{}", table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, body TEXT, payload BYTEA)", + qualified + )) + .await + .expect("create large field probe table failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, body, payload) VALUES (1, repeat('x', 70000), decode(repeat('ab', 2048), 'hex'))", + qualified + )) + 
.await + .expect("insert large field probe row failed"); + + let result = driver + .execute_query(format!( + "SELECT body, payload FROM {} WHERE id = 1", + qualified + )) + .await + .expect("select large field probe row failed"); + assert_eq!(result.row_count, 1); + let row = result.data.first().expect("large field row should exist"); + let body = row + .get("body") + .and_then(|v| v.as_str()) + .expect("body should be string"); + assert_eq!(body.len(), 70000); + assert!(row.get("payload").is_some(), "payload should exist"); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_error_handling_for_sql_error() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&form)).await; + + let err = driver + .execute_query("SELECT * FROM __dbpaw_table_not_exists".to_string()) + .await + .expect_err("invalid SQL should return query error"); + assert!( + err.contains("[QUERY_ERROR]"), + "unexpected error shape: {}", + err + ); +} + +#[tokio::test] +#[ignore] +async fn test_postgres_concurrent_connections_can_query() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let mut handles = Vec::new(); + + for _ in 0..8 { + let task_form = form.clone(); + handles.push(tokio::spawn(async move { + let driver = + postgres_context::connect_with_retry(|| PostgresDriver::connect(&task_form)).await; + driver.execute_query("SELECT 1 AS ok".to_string()).await + })); + } + + for handle in handles { + let result = handle.await.expect("concurrent postgres task panicked"); + let data = result.expect("concurrent postgres query failed"); + assert_eq!(data.row_count, 1); + let ok = 
&data.data[0]["ok"]; + let matches = ok == "1" || *ok == serde_json::Value::Number(1.into()); + assert!(matches, "ok should be 1, got {}", ok); + } +} + +#[tokio::test] +#[ignore] +async fn test_postgres_view_can_be_listed_and_queried() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&form)).await; + + let base_table = "dbpaw_pg_view_base_probe"; + let view_name = "dbpaw_pg_view_probe_v"; + let qualified_table = format!("public.{}", base_table); + let qualified_view = format!("public.{}", view_name); + + let _ = driver + .execute_query(format!("DROP VIEW IF EXISTS {}", qualified_view)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified_table)) + .await; + + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name TEXT, score INT)", + qualified_table + )) + .await + .expect("create base table for view failed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name, score) VALUES (1, 'alice', 10), (2, 'bob', 20)", + qualified_table + )) + .await + .expect("insert base rows for view failed"); + driver + .execute_query(format!( + "CREATE VIEW {} AS SELECT id, name FROM {} WHERE score >= 20", + qualified_view, qualified_table + )) + .await + .expect("create view failed"); + + let tables = driver + .list_tables(Some("public".to_string())) + .await + .expect("list_tables failed"); + assert!( + tables + .iter() + .any(|t| t.name == base_table && t.r#type == "BASE TABLE"), + "list_tables should include base table" + ); + assert!( + tables + .iter() + .any(|t| t.name == view_name && t.r#type == "VIEW"), + "list_tables should include view with type=VIEW" + ); + + let view_rows = driver + .execute_query(format!( + "SELECT id, name FROM {} ORDER BY id", + qualified_view + )) + .await + .expect("select from 
view failed"); + assert_eq!(view_rows.row_count, 1); + let row = view_rows.data.first().expect("view row should exist"); + let id_matches = row["id"] == serde_json::Value::Number(2.into()) + || row["id"] == serde_json::Value::String("2".to_string()); + assert!(id_matches, "unexpected id payload: {}", row["id"]); + assert_eq!(row["name"], serde_json::Value::String("bob".to_string())); + + let _ = driver + .execute_query(format!("DROP VIEW IF EXISTS {}", qualified_view)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified_table)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_connection_failure_with_wrong_password() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, mut form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + form.password = Some("dbpaw_wrong_password".to_string()); + + let err = match PostgresDriver::connect(&form).await { + Ok(_) => panic!("wrong password should fail"), + Err(err) => err, + }; + assert!( + err.starts_with("[CONN_FAILED]"), + "unexpected error: {}", + err + ); + assert!(!err.trim().is_empty(), "error message should not be empty"); +} + +#[tokio::test] +#[ignore] +async fn test_postgres_connection_timeout_or_unreachable_host_error() { + let form = dbpaw_lib::models::ConnectionForm { + driver: "postgres".to_string(), + host: Some("203.0.113.1".to_string()), + port: Some(5432), + username: Some("postgres".to_string()), + password: Some("postgres".to_string()), + database: Some("postgres".to_string()), + ssl: Some(false), + ..Default::default() + }; + + let err = match PostgresDriver::connect(&form).await { + Ok(_) => panic!("unreachable host should fail"), + Err(err) => err, + }; + assert!( + err.starts_with("[CONN_FAILED]"), + "unexpected error: {}", + err + ); + assert!( + err.contains("could not reach the server") + || err.to_ascii_lowercase().contains("timed out") + || 
err.to_ascii_lowercase().contains("timeout") + || err.to_ascii_lowercase().contains("network unreachable"), + "unexpected timeout/unreachable error: {}", + err + ); +} + +#[tokio::test] +#[ignore] +async fn test_postgres_batch_insert_and_batch_execute_flow() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&form)).await; + + let table_name = "dbpaw_pg_batch_probe"; + let qualified = format!("public.{}", table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, category TEXT, score INT)", + qualified + )) + .await + .expect("create batch probe table failed"); + + let value_rows: Vec<String> = (1..=50) + .map(|id| { + let category = if id <= 25 { "alpha" } else { "beta" }; + format!("({}, '{}', {})", id, category, id) + }) + .collect(); + let insert_sql = format!( + "INSERT INTO {} (id, category, score) VALUES {}", + qualified, + value_rows.join(", ") + ); + let inserted = driver + .execute_query(insert_sql) + .await + .expect("batch insert failed"); + assert!(inserted.success); + + let batch_sqls = vec![ + format!( + "UPDATE {} SET score = score + 100 WHERE id <= 10", + qualified + ), + format!( + "UPDATE {} SET category = 'gamma' WHERE id BETWEEN 30 AND 40", + qualified + ), + format!("DELETE FROM {} WHERE id IN (3, 6, 9, 12, 15)", qualified), + ]; + let mut affected = Vec::new(); + for sql in batch_sqls { + let result = driver + .execute_query(sql) + .await + .expect("batch execute statement failed"); + affected.push(result.row_count); + } + assert_eq!(affected.len(), 3); + + let check_total = driver + .execute_query(format!("SELECT COUNT(*) AS c FROM {}", qualified)) + .await + .expect("count after batch execute failed"); + let total = 
check_total.data[0]["c"] + .as_str() + .expect("count should be string") + .parse::<i64>() + .expect("count should be numeric"); + assert_eq!(total, 45); + + let check_gamma = driver + .execute_query(format!( + "SELECT COUNT(*) AS c FROM {} WHERE category = 'gamma'", + qualified + )) + .await + .expect("count gamma rows failed"); + let gamma = check_gamma.data[0]["c"] + .as_str() + .expect("gamma count should be string") + .parse::<i64>() + .expect("gamma count should be numeric"); + assert_eq!(gamma, 11); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_prepared_statements_prepare_execute_and_deallocate() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&form)).await; + + let table_name = "dbpaw_pg_prepared_stmt_probe"; + let qualified = format!("public.{}", table_name); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name TEXT)", + qualified + )) + .await + .expect("create prepared stmt probe table failed"); + + let mut conn = driver + .pool + .acquire() + .await + .expect("acquire postgres pooled connection failed"); + let prepared_insert_sql = format!("INSERT INTO {} (id, name) VALUES ($1, $2)", qualified); + let insert_a = sqlx::query(&prepared_insert_sql) + .bind(1_i64) + .bind("alice") + .execute(&mut *conn) + .await + .expect("prepared insert alice failed"); + assert_eq!(insert_a.rows_affected(), 1); + let insert_b = sqlx::query(&prepared_insert_sql) + .bind(2_i64) + .bind("bob") + .execute(&mut *conn) + .await + .expect("prepared insert bob failed"); + assert_eq!(insert_b.rows_affected(), 1); + + let prepared_update_sql = format!("UPDATE {} SET name 
= $1 WHERE id = $2", qualified); + let updated = sqlx::query(&prepared_update_sql) + .bind("alice-updated") + .bind(1_i64) + .execute(&mut *conn) + .await + .expect("prepared update failed"); + assert_eq!(updated.rows_affected(), 1); + + let prepared_select_sql = format!("SELECT name FROM {} WHERE id = $1", qualified); + let selected_name: String = sqlx::query_scalar(&prepared_select_sql) + .bind(1_i64) + .fetch_one(&mut *conn) + .await + .expect("prepared select failed"); + assert_eq!(selected_name, "alice-updated"); + drop(conn); + + let verify = driver + .execute_query(format!("SELECT COUNT(*) AS c FROM {}", qualified)) + .await + .expect("verify prepared writes failed"); + let total = verify.data[0]["c"] + .as_str() + .expect("verify count should be string") + .parse::<i64>() + .expect("verify count should parse"); + assert_eq!(total, 2); + + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; +} diff --git a/src-tauri/tests/postgres_stateful_command_integration.rs b/src-tauri/tests/postgres_stateful_command_integration.rs new file mode 100644 index 0000000..fe1cc68 --- /dev/null +++ b/src-tauri/tests/postgres_stateful_command_integration.rs @@ -0,0 +1,852 @@ +#[path = "common/postgres_context.rs"] +mod postgres_context; + +use dbpaw_lib::ai::types::AiChatRequest; +use dbpaw_lib::commands::connection::{self, CreateDatabasePayload}; +use dbpaw_lib::commands::metadata; +use dbpaw_lib::commands::{ai, query, storage, transfer}; +use dbpaw_lib::db::drivers::postgres::PostgresDriver; +use dbpaw_lib::db::drivers::DatabaseDriver; +use dbpaw_lib::db::local::LocalDb; +use dbpaw_lib::models::{AiProviderForm, ConnectionForm}; +use dbpaw_lib::state::AppState; +use std::fs; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; +use testcontainers::clients::Cli; + +fn unique_name(prefix: &str) -> String { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after unix epoch") + .as_millis(); + 
format!("{}_{}", prefix, millis) +} + +async fn wait_until_postgres_ready(form: &ConnectionForm) { + let probe = form.clone(); + let driver = postgres_context::connect_with_retry(|| PostgresDriver::connect(&probe)).await; + driver + .test_connection() + .await + .expect("postgres should accept connections for stateful command tests"); + driver.close().await; +} + +async fn init_state_with_local_db() -> AppState { + let state = AppState::new(); + let local_db_dir = std::env::temp_dir().join(unique_name("dbpaw_localdb_stateful_it")); + let db = LocalDb::init_with_app_dir(&local_db_dir) + .await + .expect("failed to initialize local db"); + let mut lock = state.local_db.lock().await; + *lock = Some(Arc::new(db)); + drop(lock); + state +} + +async fn create_postgres_connection_for_state( + state: &AppState, + base_form: &ConnectionForm, + suffix: &str, +) -> i64 { + let mut form = base_form.clone(); + form.name = Some(format!("postgres-stateful-{suffix}")); + let created = connection::create_connection_direct(state, form) + .await + .expect("create_connection should succeed"); + created.id +} + +async fn drop_database_if_exists(form: &ConnectionForm, db_name: &str) { + let driver = PostgresDriver::connect(form) + .await + .expect("failed to connect postgres driver for cleanup"); + let _ = driver + .execute_query(format!("DROP DATABASE IF EXISTS \"{}\"", db_name)) + .await; + driver.close().await; +} + +async fn prepare_metadata_fixture( + form: &ConnectionForm, + schema: &str, + parent_table: &str, + child_table: &str, +) { + let driver = PostgresDriver::connect(form) + .await + .expect("failed to connect postgres driver for metadata fixture"); + let parent_qualified = format!("\"{}\".\"{}\"", schema, parent_table); + let child_qualified = format!("\"{}\".\"{}\"", schema, child_table); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", child_qualified)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", parent_qualified)) + 
.await; + + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, code TEXT)", + parent_qualified + )) + .await + .expect("create metadata parent table should succeed"); + driver + .execute_query(format!( + "CREATE TABLE {} (\ + id INT PRIMARY KEY, \ + parent_id INT NOT NULL, \ + name TEXT, \ + CONSTRAINT fk_child_parent FOREIGN KEY (parent_id) REFERENCES {}(id)\ + )", + child_qualified, parent_qualified + )) + .await + .expect("create metadata child table should succeed"); + driver + .execute_query(format!( + "CREATE INDEX idx_child_name ON {} (name)", + child_qualified + )) + .await + .expect("create metadata child index should succeed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, code) VALUES (1, 'p1')", + parent_qualified + )) + .await + .expect("insert parent row should succeed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, parent_id, name) VALUES (10, 1, 'child-a')", + child_qualified + )) + .await + .expect("insert child row should succeed"); + driver.close().await; +} + +async fn cleanup_metadata_fixture( + form: &ConnectionForm, + schema: &str, + parent_table: &str, + child_table: &str, +) { + let driver = PostgresDriver::connect(form) + .await + .expect("failed to connect postgres driver for metadata cleanup"); + let parent_qualified = format!("\"{}\".\"{}\"", schema, parent_table); + let child_qualified = format!("\"{}\".\"{}\"", schema, child_table); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", child_qualified)) + .await; + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", parent_qualified)) + .await; + driver.close().await; +} + +async fn get_local_db(state: &AppState) -> Arc<LocalDb> { + let lock = state.local_db.lock().await; + lock.as_ref() + .cloned() + .expect("local db should be initialized") +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_create_database_by_id_success() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); 
+ let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_postgres_connection_for_state(&state, &form, "create-db-success").await; + + let db_name = unique_name("dbpaw_cmd_created_db"); + let payload = CreateDatabasePayload { + name: db_name.clone(), + if_not_exists: Some(true), + charset: None, + collation: None, + encoding: None, + lc_collate: None, + lc_ctype: None, + }; + + connection::create_database_by_id_direct(&state, conn_id, payload) + .await + .expect("create_database_by_id should succeed"); + let dbs = connection::list_databases_by_id_direct(&state, conn_id) + .await + .expect("list_databases_by_id should succeed"); + assert!(dbs.iter().any(|d| d == &db_name)); + + drop_database_if_exists(&form, &db_name).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_create_database_by_id_if_not_exists_idempotent() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_postgres_connection_for_state(&state, &form, "create-db-idempotent").await; + + let db_name = unique_name("dbpaw_cmd_idempotent_db"); + let payload = CreateDatabasePayload { + name: db_name.clone(), + if_not_exists: Some(true), + charset: None, + collation: None, + encoding: None, + lc_collate: None, + lc_ctype: None, + }; + + connection::create_database_by_id_direct(&state, conn_id, payload.clone()) + .await + .expect("first create_database_by_id should succeed"); + connection::create_database_by_id_direct(&state, conn_id, payload) + .await + .expect("second create_database_by_id should succeed"); + + drop_database_if_exists(&form, 
&db_name).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_create_database_by_id_invalid_name_returns_validation_error() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_postgres_connection_for_state(&state, &form, "invalid-db-name").await; + + let payload = CreateDatabasePayload { + name: " ".to_string(), + if_not_exists: Some(true), + charset: None, + collation: None, + encoding: None, + lc_collate: None, + lc_ctype: None, + }; + let result = connection::create_database_by_id_direct(&state, conn_id, payload).await; + assert!(result.is_err()); + let err = result.err().unwrap_or_default(); + assert!(err.contains("[VALIDATION_ERROR]")); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_list_databases_by_id_success() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_postgres_connection_for_state(&state, &form, "list-db-success").await; + + let target_db = form + .database + .clone() + .unwrap_or_else(|| "postgres".to_string()); + let dbs = connection::list_databases_by_id_direct(&state, conn_id) + .await + .expect("list_databases_by_id should succeed"); + assert!(!dbs.is_empty()); + assert!(dbs.iter().any(|d| d == &target_db)); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_list_databases_by_id_invalid_id_returns_error() { + let 
state = init_state_with_local_db().await; + let result = connection::list_databases_by_id_direct(&state, -999_999).await; + assert!(result.is_err()); + let err = result.err().unwrap_or_default(); + assert!(!err.trim().is_empty()); +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_connection_crud_flow_create_get_update_delete() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + + let unique = unique_name("dbpaw_cmd_conn"); + let mut create_form = form.clone(); + create_form.name = Some(format!("postgres-{unique}-created")); + let created = connection::create_connection_direct(&state, create_form) + .await + .expect("create_connection should succeed"); + let conn_id = created.id; + + let listed = connection::get_connections_direct(&state) + .await + .expect("get_connections after create should succeed"); + assert!(listed.iter().any(|c| c.id == conn_id)); + + let mut update_form = form.clone(); + update_form.name = Some(format!("postgres-{unique}-updated")); + update_form.database = form.database.clone().or(Some("postgres".to_string())); + let updated = connection::update_connection_direct(&state, conn_id, update_form) + .await + .expect("update_connection should succeed"); + assert_eq!(updated.id, conn_id); + assert_eq!(updated.name, format!("postgres-{unique}-updated")); + + connection::delete_connection_direct(&state, conn_id) + .await + .expect("delete_connection should succeed"); + let listed_after_delete = connection::get_connections_direct(&state) + .await + .expect("get_connections after delete should succeed"); + assert!(!listed_after_delete.iter().any(|c| c.id == conn_id)); +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_get_table_structure_success() { + let docker = 
(!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = + create_postgres_connection_for_state(&state, &form, "meta-structure-success").await; + let schema = "public".to_string(); + let parent = unique_name("dbpaw_meta_parent"); + let child = unique_name("dbpaw_meta_child"); + prepare_metadata_fixture(&form, &schema, &parent, &child).await; + + let structure = + metadata::get_table_structure_direct(&state, conn_id, schema.clone(), child.clone()) + .await + .expect("get_table_structure should succeed"); + assert!(structure.columns.iter().any(|c| c.name == "id")); + assert!(structure.columns.iter().any(|c| c.name == "parent_id")); + + cleanup_metadata_fixture(&form, &schema, &parent, &child).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_get_table_structure_missing_table_returns_error() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = + create_postgres_connection_for_state(&state, &form, "meta-structure-missing").await; + let schema = "public".to_string(); + let missing_table = unique_name("dbpaw_meta_missing"); + + let result = metadata::get_table_structure_direct(&state, conn_id, schema, missing_table).await; + assert!(result.is_err()); + let err = result.err().unwrap_or_default(); + assert!(!err.trim().is_empty()); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_get_table_ddl_success() { + let docker = 
(!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_postgres_connection_for_state(&state, &form, "meta-ddl-success").await; + let schema = "public".to_string(); + let database = form + .database + .clone() + .unwrap_or_else(|| "postgres".to_string()); + let parent = unique_name("dbpaw_meta_parent"); + let child = unique_name("dbpaw_meta_child"); + prepare_metadata_fixture(&form, &schema, &parent, &child).await; + + let ddl = metadata::get_table_ddl_direct( + &state, + conn_id, + Some(database), + schema.clone(), + child.clone(), + ) + .await + .expect("get_table_ddl should succeed"); + assert!(ddl.to_uppercase().contains("CREATE TABLE")); + assert!(ddl.contains(&child)); + + cleanup_metadata_fixture(&form, &schema, &parent, &child).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_get_table_metadata_contains_indexes_and_foreign_keys() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = + create_postgres_connection_for_state(&state, &form, "meta-metadata-success").await; + let schema = "public".to_string(); + let database = form + .database + .clone() + .unwrap_or_else(|| "postgres".to_string()); + let parent = unique_name("dbpaw_meta_parent"); + let child = unique_name("dbpaw_meta_child"); + prepare_metadata_fixture(&form, &schema, &parent, &child).await; + + let meta = metadata::get_table_metadata_direct( + &state, + conn_id, + Some(database), + schema.clone(), + child.clone(), + ) + .await + .expect("get_table_metadata should 
succeed"); + assert!(meta.indexes.iter().any(|idx| idx.name == "idx_child_name")); + assert!(meta.foreign_keys.iter().any(|fk| fk.column == "parent_id")); + + cleanup_metadata_fixture(&form, &schema, &parent, &child).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_get_schema_overview_contains_target_schema() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_postgres_connection_for_state(&state, &form, "meta-schema-overview").await; + let schema = "public".to_string(); + let database = form + .database + .clone() + .unwrap_or_else(|| "postgres".to_string()); + let parent = unique_name("dbpaw_meta_parent"); + let child = unique_name("dbpaw_meta_child"); + prepare_metadata_fixture(&form, &schema, &parent, &child).await; + + let overview = + metadata::get_schema_overview_direct(&state, conn_id, Some(database), Some(schema.clone())) + .await + .expect("get_schema_overview should succeed"); + assert!(overview + .tables + .iter() + .any(|t| t.schema == schema && t.name == child)); + + cleanup_metadata_fixture(&form, &schema, &parent, &child).await; + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_execute_query_by_id_success() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_postgres_connection_for_state(&state, &form, "query-by-id-success").await; + let database = form + .database + .clone() + .unwrap_or_else(|| 
"postgres".to_string()); + + let result = query::execute_query_by_id_direct( + &state, + conn_id, + "SELECT 1 AS v".to_string(), + Some(database), + Some("phase4_success".to_string()), + Some("phase4-qid-success".to_string()), + ) + .await + .expect("execute_query_by_id should succeed"); + assert!(result.success); + assert!(result.row_count >= 1); + assert!(!result.data.is_empty()); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_execute_query_by_id_invalid_sql_returns_error() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_postgres_connection_for_state(&state, &form, "query-by-id-invalid").await; + let database = form + .database + .clone() + .unwrap_or_else(|| "postgres".to_string()); + + let result = query::execute_query_by_id_direct( + &state, + conn_id, + "SELECT * FROM __dbpaw_missing_phase4_table".to_string(), + Some(database), + Some("phase4_invalid".to_string()), + Some("phase4-qid-invalid".to_string()), + ) + .await; + assert!(result.is_err()); + let err = result.err().unwrap_or_default(); + assert!(!err.trim().is_empty()); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_list_sql_execution_logs_contains_recent_entries() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_postgres_connection_for_state(&state, &form, "query-log-list").await; + let database = form + .database + .clone() + .unwrap_or_else(|| 
"postgres".to_string()); + + query::execute_query_by_id_direct( + &state, + conn_id, + "SELECT 1 AS phase4_log_probe".to_string(), + Some(database), + Some("phase4_log_probe".to_string()), + Some("phase4-qid-log".to_string()), + ) + .await + .expect("execute_query_by_id for log probe should succeed"); + + let logs = query::list_sql_execution_logs_direct(&state, Some(20)) + .await + .expect("list_sql_execution_logs should succeed"); + assert!(!logs.is_empty()); + assert!(logs.iter().any(|l| { + l.source.as_deref() == Some("phase4_log_probe") + && l.sql.contains("phase4_log_probe") + && l.success + })); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_cancel_query_non_clickhouse_returns_false() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_postgres_connection_for_state(&state, &form, "query-cancel-non-ch").await; + + let canceled = + query::cancel_query_direct(&state, conn_id.to_string(), "phase4-qid-cancel".to_string()) + .await + .expect("cancel_query should return bool"); + assert!(!canceled); + + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_storage_saved_query_crud_flow() { + let state = init_state_with_local_db().await; + let name = unique_name("saved_query"); + let created = storage::save_query_direct( + &state, + name.clone(), + "SELECT 1".to_string(), + Some("desc".to_string()), + None, + Some("postgres".to_string()), + ) + .await + .expect("save_query should succeed"); + assert_eq!(created.name, name); + + let all = storage::get_saved_queries_direct(&state) + .await + .expect("get_saved_queries should succeed"); + assert!(all.iter().any(|q| q.id 
== created.id)); + + let updated = storage::update_saved_query_direct( + &state, + created.id, + format!("{}_updated", created.name), + "SELECT 2".to_string(), + Some("desc2".to_string()), + None, + Some("postgres".to_string()), + ) + .await + .expect("update_saved_query should succeed"); + assert_eq!(updated.query, "SELECT 2"); + + storage::delete_saved_query_direct(&state, created.id) + .await + .expect("delete_saved_query should succeed"); + let all_after = storage::get_saved_queries_direct(&state) + .await + .expect("get_saved_queries after delete should succeed"); + assert!(!all_after.iter().any(|q| q.id == created.id)); +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_transfer_export_and_import_minimal_flow() { + let docker = (!postgres_context::should_reuse_local_db()).then(Cli::default); + let (_container, form) = postgres_context::postgres_form_from_test_context(docker.as_ref()); + wait_until_postgres_ready(&form).await; + let state = init_state_with_local_db().await; + let conn_id = create_postgres_connection_for_state(&state, &form, "transfer-minimal").await; + let database = form + .database + .clone() + .unwrap_or_else(|| "postgres".to_string()); + let schema = "public".to_string(); + let table = unique_name("dbpaw_transfer_src"); + let qualified = format!("\"{}\".\"{}\"", schema, table); + let driver = PostgresDriver::connect(&form) + .await + .expect("failed to connect postgres driver"); + let _ = driver + .execute_query(format!("DROP TABLE IF EXISTS {}", qualified)) + .await; + driver + .execute_query(format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name TEXT)", + qualified + )) + .await + .expect("create transfer src table should succeed"); + driver + .execute_query(format!( + "INSERT INTO {} (id, name) VALUES (1, 'a'), (2, 'b')", + qualified + )) + .await + .expect("insert transfer src rows should succeed"); + driver.close().await; + + let base = std::env::temp_dir().join(unique_name("dbpaw_transfer_it")); + 
fs::create_dir_all(&base).expect("create temp transfer dir should succeed"); + let table_export_path = base.join("table_export.csv"); + let query_export_path = base.join("query_export.json"); + let import_sql_path = base.join("import.sql"); + + let table_export = transfer::export_table_data_direct( + &state, + conn_id, + Some(database.clone()), + schema.clone(), + table.clone(), + "postgres".to_string(), + transfer::ExportFormat::Csv, + transfer::ExportScope::FullTable, + None, + None, + None, + None, + None, + None, + Some(table_export_path.to_string_lossy().to_string()), + Some(100), + ) + .await + .expect("export_table_data should succeed"); + assert!(table_export.row_count >= 2); + assert!(std::path::Path::new(&table_export.file_path).exists()); + + let query_export = transfer::export_query_result_direct( + &state, + conn_id, + Some(database.clone()), + format!("SELECT * FROM {} ORDER BY id", qualified), + "postgres".to_string(), + transfer::ExportFormat::Json, + Some(query_export_path.to_string_lossy().to_string()), + ) + .await + .expect("export_query_result should succeed"); + assert!(query_export.row_count >= 2); + assert!(std::path::Path::new(&query_export.file_path).exists()); + + let import_table = unique_name("dbpaw_import_dst"); + let import_sql = format!( + "CREATE TABLE \"{}\".\"{}\" (id INT PRIMARY KEY, name TEXT);", + schema, import_table + ); + fs::write(&import_sql_path, import_sql).expect("write import sql file should succeed"); + let import_result = transfer::import_sql_file_direct( + &state, + conn_id, + Some(database.clone()), + import_sql_path.to_string_lossy().to_string(), + "postgres".to_string(), + ) + .await + .expect("import_sql_file should succeed"); + assert!(import_result.success_statements >= 1); + assert!(import_result.error.is_none()); + + let cleanup_driver = PostgresDriver::connect(&form) + .await + .expect("failed to connect postgres driver for transfer cleanup"); + let _ = cleanup_driver + .execute_query(format!("DROP TABLE IF 
EXISTS {}", qualified)) + .await; + let _ = cleanup_driver + .execute_query(format!( + "DROP TABLE IF EXISTS \"{}\".\"{}\"", + schema, import_table + )) + .await; + cleanup_driver.close().await; + let _ = fs::remove_file(table_export_path); + let _ = fs::remove_file(query_export_path); + let _ = fs::remove_file(import_sql_path); + let _ = fs::remove_dir_all(base); + let _ = connection::delete_connection_direct(&state, conn_id).await; +} + +#[tokio::test] +#[ignore] +async fn test_postgres_command_ai_minimal_provider_conversation_and_chat_flow() { + let state = init_state_with_local_db().await; + + let start_without_provider = ai::ai_chat_start_direct( + &state, + AiChatRequest { + request_id: unique_name("ai_start_no_provider"), + provider_id: None, + conversation_id: None, + scenario: "sql_generate".to_string(), + input: "select 1".to_string(), + title: Some("phase5 start".to_string()), + connection_id: None, + database: None, + schema_overview: None, + selected_tables: None, + }, + ) + .await; + assert!(start_without_provider.is_err()); + + let created_provider = ai::ai_create_provider_direct( + &state, + AiProviderForm { + name: unique_name("ai_provider"), + provider_type: Some("openai".to_string()), + base_url: "https://example.invalid/v1".to_string(), + model: "gpt-4o-mini".to_string(), + api_key: Some("sk-test".to_string()), + is_default: Some(true), + enabled: Some(true), + extra_json: None, + }, + ) + .await + .expect("ai_create_provider should succeed"); + let providers = ai::ai_list_providers_direct(&state) + .await + .expect("ai_list_providers should succeed"); + assert!(providers.iter().any(|p| p.id == created_provider.id)); + + let updated_provider = ai::ai_update_provider_direct( + &state, + created_provider.id, + AiProviderForm { + name: format!("{}_updated", created_provider.name), + provider_type: Some("openai".to_string()), + base_url: "https://example.invalid/v1".to_string(), + model: "gpt-4o-mini".to_string(), + api_key: 
Some("sk-test-2".to_string()), + is_default: Some(true), + enabled: Some(true), + extra_json: None, + }, + ) + .await + .expect("ai_update_provider should succeed"); + assert_eq!(updated_provider.id, created_provider.id); + + ai::ai_set_default_provider_direct(&state, created_provider.id) + .await + .expect("ai_set_default_provider should succeed"); + ai::ai_clear_provider_api_key_direct(&state, "openai".to_string()) + .await + .expect("ai_clear_provider_api_key should succeed"); + + let continue_without_conversation = ai::ai_chat_continue_direct( + &state, + AiChatRequest { + request_id: unique_name("ai_continue_no_conv"), + provider_id: Some(created_provider.id), + conversation_id: None, + scenario: "sql_generate".to_string(), + input: "continue".to_string(), + title: None, + connection_id: None, + database: None, + schema_overview: None, + selected_tables: None, + }, + ) + .await; + assert!(continue_without_conversation.is_err()); + + let db = get_local_db(&state).await; + let conv = db + .create_ai_conversation( + unique_name("ai_conv"), + "sql_generate".to_string(), + None, + None, + ) + .await + .expect("create ai conversation in local db should succeed"); + let conversations = ai::ai_list_conversations_direct(&state, None, None) + .await + .expect("ai_list_conversations should succeed"); + assert!(conversations.iter().any(|c| c.id == conv.id)); + let detail = ai::ai_get_conversation_direct(&state, conv.id) + .await + .expect("ai_get_conversation should succeed"); + assert_eq!(detail.conversation.id, conv.id); + ai::ai_delete_conversation_direct(&state, conv.id) + .await + .expect("ai_delete_conversation should succeed"); + let conversations_after = ai::ai_list_conversations_direct(&state, None, None) + .await + .expect("ai_list_conversations after delete should succeed"); + assert!(!conversations_after.iter().any(|c| c.id == conv.id)); + + ai::ai_delete_provider_direct(&state, created_provider.id) + .await + .expect("ai_delete_provider should succeed"); +} 
diff --git a/src-tauri/tests/sqlite_integration.rs b/src-tauri/tests/sqlite_integration.rs index 871fa47..016950a 100644 --- a/src-tauri/tests/sqlite_integration.rs +++ b/src-tauri/tests/sqlite_integration.rs @@ -14,6 +14,18 @@ fn sqlite_test_path() -> PathBuf { p } +fn json_to_i64(value: &serde_json::Value) -> i64 { + if let Some(v) = value.as_i64() { + return v; + } + if let Some(v) = value.as_str() { + return v + .parse::() + .expect("string value should be numeric for integer assertion"); + } + panic!("value should be i64 or numeric string, got {}", value); +} + #[tokio::test] #[ignore] async fn test_sqlite_integration_flow() { @@ -43,7 +55,7 @@ async fn test_sqlite_integration_flow() { driver .execute_query( - "CREATE TABLE IF NOT EXISTS sqlite_type_probe (\ + "CREATE TABLE IF NOT EXISTS dbpaw_sqlite_type_probe (\ id INTEGER PRIMARY KEY, \ name TEXT, \ amount NUMERIC, \ @@ -57,8 +69,8 @@ async fn test_sqlite_integration_flow() { driver .execute_query( - "CREATE VIEW IF NOT EXISTS sqlite_type_probe_v AS \ - SELECT id, name FROM sqlite_type_probe" + "CREATE VIEW IF NOT EXISTS dbpaw_sqlite_type_probe_v AS \ + SELECT id, name FROM dbpaw_sqlite_type_probe" .to_string(), ) .await @@ -66,7 +78,7 @@ async fn test_sqlite_integration_flow() { driver .execute_query( - "INSERT INTO sqlite_type_probe (id, name, amount, payload, created_at) \ + "INSERT INTO dbpaw_sqlite_type_probe (id, name, amount, payload, created_at) \ VALUES (1, 'hello', 12.34, x'DEADBEEF', '2026-01-02 03:04:05')" .to_string(), ) @@ -75,16 +87,16 @@ async fn test_sqlite_integration_flow() { let tables = driver.list_tables(None).await.expect("list_tables failed"); assert!( - tables.iter().any(|t| t.name == "sqlite_type_probe"), - "list_tables should include sqlite_type_probe" + tables.iter().any(|t| t.name == "dbpaw_sqlite_type_probe"), + "list_tables should include dbpaw_sqlite_type_probe" ); assert!( - tables.iter().any(|t| t.name == "sqlite_type_probe_v"), - "list_tables should include 
sqlite_type_probe_v" + tables.iter().any(|t| t.name == "dbpaw_sqlite_type_probe_v"), + "list_tables should include dbpaw_sqlite_type_probe_v" ); let metadata = driver - .get_table_metadata("main".to_string(), "sqlite_type_probe".to_string()) + .get_table_metadata("main".to_string(), "dbpaw_sqlite_type_probe".to_string()) .await .expect("get_table_metadata failed"); assert!( @@ -100,7 +112,7 @@ async fn test_sqlite_integration_flow() { ); let ddl = driver - .get_table_ddl("main".to_string(), "sqlite_type_probe".to_string()) + .get_table_ddl("main".to_string(), "dbpaw_sqlite_type_probe".to_string()) .await .expect("get_table_ddl failed"); assert!( @@ -110,7 +122,7 @@ async fn test_sqlite_integration_flow() { let result = driver .execute_query( - "SELECT id, name, amount, payload, created_at FROM sqlite_type_probe WHERE id = 1" + "SELECT id, name, amount, payload, created_at FROM dbpaw_sqlite_type_probe WHERE id = 1" .to_string(), ) .await @@ -126,12 +138,845 @@ async fn test_sqlite_integration_flow() { assert!(row.get("payload").is_some(), "payload should exist"); let _ = driver - .execute_query("DROP VIEW IF EXISTS sqlite_type_probe_v".to_string()) + .execute_query("DROP VIEW IF EXISTS dbpaw_sqlite_type_probe_v".to_string()) .await; let _ = driver - .execute_query("DROP TABLE IF EXISTS sqlite_type_probe".to_string()) + .execute_query("DROP TABLE IF EXISTS dbpaw_sqlite_type_probe".to_string()) + .await; + driver.close().await; + + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_get_table_data_supports_pagination_sort_filter_and_order_by() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver = SqliteDriver::connect(&form) + .await + .expect("Failed to connect to sqlite db"); + + driver + .execute_query( + "CREATE TABLE dbpaw_sqlite_grid_probe (id 
INTEGER PRIMARY KEY, name TEXT, score INTEGER)" + .to_string(), + ) + .await + .expect("create dbpaw_sqlite_grid_probe failed"); + driver + .execute_query( + "INSERT INTO dbpaw_sqlite_grid_probe (id, name, score) VALUES \ + (1, 'alpha', 10), (2, 'beta', 20), (3, 'gamma', 30), (4, 'delta', 40)" + .to_string(), + ) + .await + .expect("insert dbpaw_sqlite_grid_probe failed"); + + let page1 = driver + .get_table_data( + "main".to_string(), + "dbpaw_sqlite_grid_probe".to_string(), + 1, + 2, + Some("score".to_string()), + Some("desc".to_string()), + None, + None, + ) + .await + .expect("get_table_data page1 failed"); + assert_eq!(page1.total, 4); + assert_eq!(page1.data.len(), 2); + assert_eq!( + page1.data[0]["name"], + serde_json::Value::String("delta".to_string()) + ); + + let filtered = driver + .get_table_data( + "main".to_string(), + "dbpaw_sqlite_grid_probe".to_string(), + 1, + 10, + None, + None, + Some("score >= 20".to_string()), + None, + ) + .await + .expect("get_table_data with filter failed"); + assert_eq!(filtered.total, 3); + + let ordered = driver + .get_table_data( + "main".to_string(), + "dbpaw_sqlite_grid_probe".to_string(), + 1, + 1, + Some("id".to_string()), + Some("asc".to_string()), + None, + Some("name DESC".to_string()), + ) + .await + .expect("get_table_data with order_by failed"); + assert_eq!(ordered.total, 4); + assert_eq!(ordered.data.len(), 1); + assert_eq!( + ordered.data[0]["name"], + serde_json::Value::String("gamma".to_string()) + ); + + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_get_table_data_rejects_invalid_sort_column() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver = SqliteDriver::connect(&form) + .await + .expect("Failed to connect to sqlite db"); + + driver + 
.execute_query("CREATE TABLE dbpaw_sqlite_invalid_sort_probe (id INTEGER PRIMARY KEY)".to_string()) + .await + .expect("create dbpaw_sqlite_invalid_sort_probe failed"); + + let result = driver + .get_table_data( + "main".to_string(), + "dbpaw_sqlite_invalid_sort_probe".to_string(), + 1, + 10, + Some("id desc".to_string()), + Some("desc".to_string()), + None, + None, + ) .await; + let err = result.expect_err("invalid sort column should return error"); + assert!( + err.contains("[VALIDATION_ERROR] Invalid sort column name"), + "unexpected error: {}", + err + ); + + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_table_structure_and_schema_overview() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver = SqliteDriver::connect(&form) + .await + .expect("Failed to connect to sqlite db"); + + driver + .execute_query( + "CREATE TABLE dbpaw_sqlite_overview_probe (id INTEGER PRIMARY KEY, label TEXT NOT NULL)" + .to_string(), + ) + .await + .expect("create dbpaw_sqlite_overview_probe failed"); + + let structure = driver + .get_table_structure("main".to_string(), "dbpaw_sqlite_overview_probe".to_string()) + .await + .expect("get_table_structure failed"); + assert!( + structure.columns.iter().any(|c| c.name == "id" && c.primary_key), + "table structure should include primary key id" + ); + assert!( + structure.columns.iter().any(|c| c.name == "label"), + "table structure should include label" + ); + + let overview = driver + .get_schema_overview(Some("main".to_string())) + .await + .expect("get_schema_overview failed"); + assert!( + overview + .tables + .iter() + .any(|t| t.schema == "main" && t.name == "dbpaw_sqlite_overview_probe"), + "schema overview should include main.dbpaw_sqlite_overview_probe" + ); + + driver.close().await; + 
let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_metadata_includes_indexes_and_foreign_keys() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver = SqliteDriver::connect(&form) + .await + .expect("Failed to connect to sqlite db"); + + driver + .execute_query( + "CREATE TABLE dbpaw_sqlite_parent_meta_probe (id INTEGER PRIMARY KEY); \ + CREATE TABLE dbpaw_sqlite_child_meta_probe (\ + id INTEGER PRIMARY KEY, \ + parent_id INTEGER NOT NULL, \ + name TEXT, \ + CONSTRAINT fk_dbpaw_sqlite_child_parent FOREIGN KEY(parent_id) REFERENCES dbpaw_sqlite_parent_meta_probe(id)\ + ); \ + CREATE INDEX idx_dbpaw_sqlite_child_name ON dbpaw_sqlite_child_meta_probe(name);" + .to_string(), + ) + .await + .expect("create sqlite metadata probe tables failed"); + + let metadata = driver + .get_table_metadata("main".to_string(), "dbpaw_sqlite_child_meta_probe".to_string()) + .await + .expect("get_table_metadata failed"); + assert!( + metadata + .indexes + .iter() + .any(|i| i.name == "idx_dbpaw_sqlite_child_name" && i.columns.contains(&"name".to_string())), + "metadata should include idx_dbpaw_sqlite_child_name" + ); + assert!( + metadata + .foreign_keys + .iter() + .any(|fk| fk.column == "parent_id" && fk.referenced_table == "dbpaw_sqlite_parent_meta_probe"), + "metadata should include FK parent_id -> dbpaw_sqlite_parent_meta_probe(id)" + ); + + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_boolean_and_json_type_mapping_regression() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver = SqliteDriver::connect(&form) + .await 
+ .expect("Failed to connect to sqlite db"); + + driver + .execute_query( + "CREATE TABLE dbpaw_sqlite_bool_json_probe (id INTEGER PRIMARY KEY, flag BOOLEAN, meta TEXT)" + .to_string(), + ) + .await + .expect("create dbpaw_sqlite_bool_json_probe failed"); + driver + .execute_query( + "INSERT INTO dbpaw_sqlite_bool_json_probe (id, flag, meta) VALUES \ + (1, 1, '{\"tier\":\"gold\"}')" + .to_string(), + ) + .await + .expect("insert dbpaw_sqlite_bool_json_probe failed"); + + let query_result = driver + .execute_query( + "SELECT flag, json_extract(meta, '$.tier') AS tier \ + FROM dbpaw_sqlite_bool_json_probe WHERE id = 1" + .to_string(), + ) + .await + .expect("select bool/json probe row failed"); + assert_eq!(query_result.row_count, 1); + let query_row = query_result.data.first().expect("query row should exist"); + assert_eq!(query_row["flag"], serde_json::Value::Bool(true)); + assert_eq!(query_row["tier"], serde_json::Value::String("gold".to_string())); + + let table_data = driver + .get_table_data( + "main".to_string(), + "dbpaw_sqlite_bool_json_probe".to_string(), + 1, + 10, + None, + None, + None, + None, + ) + .await + .expect("get_table_data for dbpaw_sqlite_bool_json_probe failed"); + assert_eq!(table_data.total, 1); + let grid_row = table_data.data.first().expect("table row should exist"); + assert_eq!(grid_row["flag"], serde_json::Value::Bool(true)); + assert!( + grid_row.get("meta").is_some(), + "meta should exist in table_data" + ); + + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_transaction_commit_and_rollback() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver = SqliteDriver::connect(&form) + .await + .expect("Failed to connect to sqlite db"); + + driver + .execute_query( + "CREATE TABLE 
dbpaw_sqlite_txn_probe (id INTEGER PRIMARY KEY, name TEXT)".to_string(), + ) + .await + .expect("create sqlite txn probe table failed"); + + let mut rollback_tx = driver.pool.begin().await.expect("begin rollback tx failed"); + sqlx::query("INSERT INTO dbpaw_sqlite_txn_probe (id, name) VALUES (?, ?)") + .bind(1_i64) + .bind("rolled_back") + .execute(&mut *rollback_tx) + .await + .expect("insert in rollback tx failed"); + rollback_tx.rollback().await.expect("rollback tx failed"); + + let rolled_back = driver + .execute_query("SELECT COUNT(*) AS c FROM dbpaw_sqlite_txn_probe WHERE id = 1".to_string()) + .await + .expect("count after rollback failed"); + let rolled_back_count = json_to_i64(&rolled_back.data[0]["c"]); + assert_eq!(rolled_back_count, 0); + + let mut commit_tx = driver.pool.begin().await.expect("begin commit tx failed"); + sqlx::query("INSERT INTO dbpaw_sqlite_txn_probe (id, name) VALUES (?, ?)") + .bind(2_i64) + .bind("committed") + .execute(&mut *commit_tx) + .await + .expect("insert in commit tx failed"); + commit_tx.commit().await.expect("commit tx failed"); + + let committed = driver + .execute_query("SELECT COUNT(*) AS c FROM dbpaw_sqlite_txn_probe WHERE id = 2".to_string()) + .await + .expect("count after commit failed"); + let committed_count = json_to_i64(&committed.data[0]["c"]); + assert_eq!(committed_count, 1); + + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_execute_query_reports_affected_rows_for_update_delete() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver = SqliteDriver::connect(&form) + .await + .expect("Failed to connect to sqlite db"); + + driver + .execute_query( + "CREATE TABLE dbpaw_sqlite_affected_rows_probe (id INTEGER PRIMARY KEY, name TEXT)" + .to_string(), + ) + .await + 
.expect("create affected_rows probe table failed"); + + let inserted = driver + .execute_query( + "INSERT INTO dbpaw_sqlite_affected_rows_probe (id, name) VALUES (1, 'a'), (2, 'b')" + .to_string(), + ) + .await + .expect("insert affected_rows probe rows failed"); + assert_eq!(inserted.row_count, 2); + + let updated = driver + .execute_query( + "UPDATE dbpaw_sqlite_affected_rows_probe SET name = 'bb' WHERE id = 2".to_string(), + ) + .await + .expect("update affected_rows probe row failed"); + assert_eq!(updated.row_count, 1); + + let deleted = driver + .execute_query("DELETE FROM dbpaw_sqlite_affected_rows_probe WHERE id IN (1, 2)".to_string()) + .await + .expect("delete affected_rows probe rows failed"); + assert_eq!(deleted.row_count, 2); + + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_large_text_and_blob_round_trip() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver = SqliteDriver::connect(&form) + .await + .expect("Failed to connect to sqlite db"); + + driver + .execute_query( + "CREATE TABLE dbpaw_sqlite_large_field_probe (id INTEGER PRIMARY KEY, body TEXT, payload BLOB)" + .to_string(), + ) + .await + .expect("create large field probe table failed"); + + let large_body = "x".repeat(70000); + let large_payload = vec![0xAB_u8; 2048]; + let mut conn = driver + .pool + .acquire() + .await + .expect("acquire sqlite pooled connection failed"); + sqlx::query("INSERT INTO dbpaw_sqlite_large_field_probe (id, body, payload) VALUES (?, ?, ?)") + .bind(1_i64) + .bind(&large_body) + .bind(&large_payload) + .execute(&mut *conn) + .await + .expect("insert large field probe row failed"); + drop(conn); + + let result = driver + .execute_query( + "SELECT body, payload FROM dbpaw_sqlite_large_field_probe WHERE id = 
1".to_string(), + ) + .await + .expect("select large field probe row failed"); + assert_eq!(result.row_count, 1); + let row = result.data.first().expect("large field row should exist"); + let body = row + .get("body") + .and_then(|v| v.as_str()) + .expect("body should be string"); + assert_eq!(body.len(), 70000); + assert!(row.get("payload").is_some(), "payload should exist"); + + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_error_handling_for_sql_error() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver = SqliteDriver::connect(&form) + .await + .expect("Failed to connect to sqlite db"); + + let err = driver + .execute_query("SELECT * FROM __dbpaw_table_not_exists".to_string()) + .await + .expect_err("invalid SQL should return query error"); + assert!( + err.contains("[QUERY_ERROR]"), + "unexpected error shape: {}", + err + ); + + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_concurrent_connections_can_query() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let mut handles = Vec::new(); + + for _ in 0..8 { + let task_form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str.clone()), + ..Default::default() + }; + handles.push(tokio::spawn(async move { + let driver = SqliteDriver::connect(&task_form) + .await + .expect("connect sqlite in concurrent task failed"); + let result = driver.execute_query("SELECT 1 AS ok".to_string()).await; + driver.close().await; + result + })); + } + + for handle in handles { + let result = handle.await.expect("concurrent sqlite task panicked"); + let data = result.expect("concurrent sqlite query failed"); + assert_eq!(data.row_count, 1); + 
let ok = &data.data[0]["ok"]; + let matches = *ok == serde_json::Value::Number(1.into()) || ok == "1"; + assert!(matches, "ok should be 1, got {}", ok); + } + + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_view_can_be_listed_and_queried() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver = SqliteDriver::connect(&form) + .await + .expect("Failed to connect to sqlite db"); + + driver + .execute_query( + "CREATE TABLE dbpaw_sqlite_view_base_probe (id INTEGER PRIMARY KEY, name TEXT, score INTEGER)" + .to_string(), + ) + .await + .expect("create base table for view failed"); + driver + .execute_query( + "INSERT INTO dbpaw_sqlite_view_base_probe (id, name, score) VALUES (1, 'alice', 10), (2, 'bob', 20)" + .to_string(), + ) + .await + .expect("insert base rows for view failed"); + driver + .execute_query( + "CREATE VIEW dbpaw_sqlite_view_probe_v AS SELECT id, name FROM dbpaw_sqlite_view_base_probe WHERE score >= 20" + .to_string(), + ) + .await + .expect("create view failed"); + + let tables = driver.list_tables(None).await.expect("list_tables failed"); + assert!( + tables + .iter() + .any(|t| t.name == "dbpaw_sqlite_view_base_probe" && t.r#type == "table"), + "list_tables should include base table" + ); + assert!( + tables + .iter() + .any(|t| t.name == "dbpaw_sqlite_view_probe_v" && t.r#type == "view"), + "list_tables should include view with type=view" + ); + + let view_rows = driver + .execute_query("SELECT id, name FROM dbpaw_sqlite_view_probe_v ORDER BY id".to_string()) + .await + .expect("select from view failed"); + assert_eq!(view_rows.row_count, 1); + let row = view_rows.data.first().expect("view row should exist"); + let id_matches = row["id"] == serde_json::Value::Number(2.into()) + || row["id"] == 
serde_json::Value::String("2".to_string()); + assert!(id_matches, "unexpected id payload: {}", row["id"]); + assert_eq!(row["name"], serde_json::Value::String("bob".to_string())); + + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_connection_failure_with_invalid_file_path() { + let mut missing_dir = env::temp_dir(); + missing_dir.push(format!("dbpaw_sqlite_missing_dir_{}", Uuid::new_v4())); + let missing_path = missing_dir.join("db.sqlite"); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(missing_path.to_string_lossy().to_string()), + ..Default::default() + }; + + let err = match SqliteDriver::connect(&form).await { + Ok(_) => panic!("invalid file path should fail"), + Err(err) => err, + }; + assert!( + err.starts_with("[CONN_FAILED]"), + "unexpected error: {}", + err + ); + assert!(!err.trim().is_empty(), "error message should not be empty"); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_lock_conflict_or_busy_error() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver_a = SqliteDriver::connect(&form) + .await + .expect("Failed to connect sqlite driver A"); + let driver_b = SqliteDriver::connect(&form) + .await + .expect("Failed to connect sqlite driver B"); + + driver_a + .execute_query( + "CREATE TABLE dbpaw_sqlite_lock_probe (id INTEGER PRIMARY KEY, name TEXT)".to_string(), + ) + .await + .expect("create lock probe table failed"); + driver_a + .execute_query("PRAGMA busy_timeout = 100".to_string()) + .await + .expect("set busy_timeout for driver A failed"); + driver_b + .execute_query("PRAGMA busy_timeout = 100".to_string()) + .await + .expect("set busy_timeout for driver B failed"); + + let mut tx = driver_a + .pool + .begin() + .await + .expect("begin write lock tx failed"); 
+ sqlx::query("INSERT INTO dbpaw_sqlite_lock_probe (id, name) VALUES (?, ?)") + .bind(1_i64) + .bind("a") + .execute(&mut *tx) + .await + .expect("insert in lock tx failed"); + + let err = driver_b + .execute_query("INSERT INTO dbpaw_sqlite_lock_probe (id, name) VALUES (2, 'b')".to_string()) + .await + .expect_err("concurrent write under lock should fail"); + assert!( + err.contains("[QUERY_ERROR]"), + "unexpected lock error shape: {}", + err + ); + let lower = err.to_ascii_lowercase(); + assert!( + lower.contains("locked") || lower.contains("busy"), + "unexpected lock/busy error: {}", + err + ); + + tx.rollback().await.expect("rollback lock tx failed"); + driver_a.close().await; + driver_b.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_batch_insert_and_batch_execute_flow() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver = SqliteDriver::connect(&form) + .await + .expect("Failed to connect to sqlite db"); + + driver + .execute_query( + "CREATE TABLE dbpaw_sqlite_batch_probe (id INTEGER PRIMARY KEY, category TEXT, score INTEGER)" + .to_string(), + ) + .await + .expect("create batch probe table failed"); + + let value_rows: Vec = (1..=50) + .map(|id| { + let category = if id <= 25 { "alpha" } else { "beta" }; + format!("({}, '{}', {})", id, category, id) + }) + .collect(); + let insert_sql = format!( + "INSERT INTO dbpaw_sqlite_batch_probe (id, category, score) VALUES {}", + value_rows.join(", ") + ); + let inserted = driver + .execute_query(insert_sql) + .await + .expect("batch insert failed"); + assert_eq!(inserted.row_count, 50); + + let batch_sqls = vec![ + "UPDATE dbpaw_sqlite_batch_probe SET score = score + 100 WHERE id <= 10".to_string(), + "UPDATE dbpaw_sqlite_batch_probe SET category = 'gamma' WHERE id BETWEEN 30 AND 
40" + .to_string(), + "DELETE FROM dbpaw_sqlite_batch_probe WHERE id IN (3, 6, 9, 12, 15)".to_string(), + ]; + let mut affected = Vec::new(); + for sql in batch_sqls { + let result = driver + .execute_query(sql) + .await + .expect("batch execute statement failed"); + affected.push(result.row_count); + } + assert_eq!(affected, vec![10, 11, 5]); + + let check_total = driver + .execute_query("SELECT COUNT(*) AS c FROM dbpaw_sqlite_batch_probe".to_string()) + .await + .expect("count after batch execute failed"); + let total = json_to_i64(&check_total.data[0]["c"]); + assert_eq!(total, 45); + + let check_gamma = driver + .execute_query( + "SELECT COUNT(*) AS c FROM dbpaw_sqlite_batch_probe WHERE category = 'gamma'" + .to_string(), + ) + .await + .expect("count gamma rows failed"); + let gamma = json_to_i64(&check_gamma.data[0]["c"]); + assert_eq!(gamma, 11); + driver.close().await; + let _ = std::fs::remove_file(db_path); +} + +#[tokio::test] +#[ignore] +async fn test_sqlite_prepared_statements_prepare_execute_and_deallocate() { + let db_path = sqlite_test_path(); + let db_path_str = db_path.to_string_lossy().to_string(); + let form = ConnectionForm { + driver: "sqlite".to_string(), + file_path: Some(db_path_str), + ..Default::default() + }; + let driver = SqliteDriver::connect(&form) + .await + .expect("Failed to connect to sqlite db"); + + driver + .execute_query( + "CREATE TABLE dbpaw_sqlite_prepared_stmt_probe (id INTEGER PRIMARY KEY, name TEXT)" + .to_string(), + ) + .await + .expect("create prepared stmt probe table failed"); + + let mut conn = driver + .pool + .acquire() + .await + .expect("acquire sqlite pooled connection failed"); + let prepared_insert_sql = + "INSERT INTO dbpaw_sqlite_prepared_stmt_probe (id, name) VALUES (?, ?)".to_string(); + let insert_a = sqlx::query(&prepared_insert_sql) + .bind(1_i64) + .bind("alice") + .execute(&mut *conn) + .await + .expect("prepared insert alice failed"); + assert_eq!(insert_a.rows_affected(), 1); + let insert_b = 
sqlx::query(&prepared_insert_sql) + .bind(2_i64) + .bind("bob") + .execute(&mut *conn) + .await + .expect("prepared insert bob failed"); + assert_eq!(insert_b.rows_affected(), 1); + let prepared_update_sql = + "UPDATE dbpaw_sqlite_prepared_stmt_probe SET name = ? WHERE id = ?".to_string(); + let updated = sqlx::query(&prepared_update_sql) + .bind("alice-updated") + .bind(1_i64) + .execute(&mut *conn) + .await + .expect("prepared update failed"); + assert_eq!(updated.rows_affected(), 1); + + let prepared_select_sql = + "SELECT name FROM dbpaw_sqlite_prepared_stmt_probe WHERE id = ?".to_string(); + let selected_name: String = sqlx::query_scalar(&prepared_select_sql) + .bind(1_i64) + .fetch_one(&mut *conn) + .await + .expect("prepared select failed"); + assert_eq!(selected_name, "alice-updated"); + drop(conn); + + let verify = driver + .execute_query("SELECT COUNT(*) AS c FROM dbpaw_sqlite_prepared_stmt_probe".to_string()) + .await + .expect("verify prepared writes failed"); + let total = json_to_i64(&verify.data[0]["c"]); + assert_eq!(total, 2); + + driver.close().await; let _ = std::fs::remove_file(db_path); } diff --git a/src/App.tsx b/src/App.tsx index 9ec60ae..d6a35ba 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -146,7 +146,9 @@ function LazyPanelFallback({ className?: string; }) { return ( -
+
{label}
); @@ -270,11 +272,13 @@ export default function App() { className="h-7 w-7 p-0" onClick={() => setAiVisible((v) => !v)} title={ + aiVisible ? t("app.window.hideAiPanel") : t("app.window.showAiPanel") + } + aria-label={ aiVisible - ? t("app.window.hideAiPanel") - : t("app.window.showAiPanel") + ? t("app.window.hideAiPanelAria") + : t("app.window.showAiPanelAria") } - aria-label={aiVisible ? t("app.window.hideAiPanelAria") : t("app.window.showAiPanelAria")} > @@ -373,54 +377,53 @@ export default function App() { Promise.allSettled([ fetchEditorDatabases(connectionId, initialDatabase), fetchEditorSchemaOverview(connectionId, initialDatabase), - ]) - .then(([availableDatabasesResult, schemaOverviewResult]) => { - if (availableDatabasesResult.status === "rejected") { - console.error( - "Failed to load editor databases:", - availableDatabasesResult.reason instanceof Error - ? availableDatabasesResult.reason.message - : String(availableDatabasesResult.reason), - ); - } - if (schemaOverviewResult.status === "rejected") { - console.error( - "Failed to load schema overview:", - schemaOverviewResult.reason instanceof Error - ? schemaOverviewResult.reason.message - : String(schemaOverviewResult.reason), - ); - } + ]).then(([availableDatabasesResult, schemaOverviewResult]) => { + if (availableDatabasesResult.status === "rejected") { + console.error( + "Failed to load editor databases:", + availableDatabasesResult.reason instanceof Error + ? availableDatabasesResult.reason.message + : String(availableDatabasesResult.reason), + ); + } + if (schemaOverviewResult.status === "rejected") { + console.error( + "Failed to load schema overview:", + schemaOverviewResult.reason instanceof Error + ? schemaOverviewResult.reason.message + : String(schemaOverviewResult.reason), + ); + } - const availableDatabases = - availableDatabasesResult.status === "fulfilled" - ? availableDatabasesResult.value - : normalizeDatabaseOptions( - initialDatabase ? 
[initialDatabase] : [], - initialDatabase, - ); - const schemaOverview = - schemaOverviewResult.status === "fulfilled" - ? schemaOverviewResult.value - : undefined; + const availableDatabases = + availableDatabasesResult.status === "fulfilled" + ? availableDatabasesResult.value + : normalizeDatabaseOptions( + initialDatabase ? [initialDatabase] : [], + initialDatabase, + ); + const schemaOverview = + schemaOverviewResult.status === "fulfilled" + ? schemaOverviewResult.value + : undefined; - setTabs((prev) => - prev.map((t) => - t.id === newTabId - ? { - ...t, - database: resolvePreferredDatabase({ - preferredDatabase: initialDatabase, - connectionDatabase: initialDatabase, - availableDatabases, - }), + setTabs((prev) => + prev.map((t) => + t.id === newTabId + ? { + ...t, + database: resolvePreferredDatabase({ + preferredDatabase: initialDatabase, + connectionDatabase: initialDatabase, availableDatabases, - schemaOverview, - } - : t, - ), - ); - }); + }), + availableDatabases, + schemaOverview, + } + : t, + ), + ); + }); }; const handleOpenSavedQuery = async (query: SavedQuery) => { @@ -578,14 +581,16 @@ export default function App() { tab.connectionId, database, ); - if (schemaOverviewRequestKeysRef.current.get(tabId) !== requestKey) return; + if (schemaOverviewRequestKeysRef.current.get(tabId) !== requestKey) + return; setTabs((prev) => prev.map((item) => item.id === tabId ? { ...item, schemaOverview } : item, ), ); } catch (e) { - if (schemaOverviewRequestKeysRef.current.get(tabId) !== requestKey) return; + if (schemaOverviewRequestKeysRef.current.get(tabId) !== requestKey) + return; const errorMessage = e instanceof Error ? e.message : String(e); console.error("Failed to switch editor database", errorMessage); toast.error(t("app.error.loadSchemaOverview"), { @@ -623,7 +628,9 @@ export default function App() { .slice(2, 8)}`; setTabs((prev) => prev.map((t) => - t.id === tabId ? { ...t, activeQueryId: queryId, lastQueryId: queryId } : t, + t.id === tabId + ? 
{ ...t, activeQueryId: queryId, lastQueryId: queryId } + : t, ), ); try { @@ -762,9 +769,12 @@ export default function App() { scope: "full_table", filePath, }); - toast.success(t("app.success.exportCompleted", { count: result.rowCount }), { - description: result.filePath, - }); + toast.success( + t("app.success.exportCompleted", { count: result.rowCount }), + { + description: result.filePath, + }, + ); } catch (e) { toast.error(t("app.error.exportFailed"), { description: e instanceof Error ? e.message : String(e), @@ -1447,67 +1457,72 @@ export default function App() { : tab.title; return ( - - - {/* Wrapper avoids data-state conflict: ContextMenu and Tabs both set it; only the trigger must get Tabs' data-state=active for the indicator bar */} - - { - if (e.button === 1) { - e.preventDefault(); - handleCloseTab(tab.id); - } - }} - > -
- {tab.type === "table" ? ( - - ) : ( - - )} - - - {title} - - {tab.type === "editor" && tab.isDirty && ( - - )} - - - - - - - - handleCloseTab(tab.id)} - > - {t("app.tab.closeTab")} - - handleCloseOtherTabs(tab.id)} - > - {t("app.tab.closeOtherTabs")} - - - - + } + }} + > +
+ {tab.type === "table" ? ( +
+ ) : ( + + )} + + + {title} + + {tab.type === "editor" && + tab.isDirty && ( + + )} + + + + + + + + handleCloseTab(tab.id)} + > + {t("app.tab.closeTab")} + + handleCloseOtherTabs(tab.id)} + > + {t("app.tab.closeOtherTabs")} + + + + ); })} @@ -1529,9 +1544,7 @@ export default function App() {
-

- {t("app.empty.hint")} -

+

{t("app.empty.hint")}

) : ( @@ -1544,9 +1557,7 @@ export default function App() { {tab.type === "editor" ? ( + } > - } + fallback={} > (new Map()); const [insertDraftRows, setInsertDraftRows] = useState([]); const [primaryKeys, setPrimaryKeys] = useState([]); + const [clickhouseEngine, setClickhouseEngine] = useState(null); const [tableColumns, setTableColumns] = useState([]); const [columnComments, setColumnComments] = useState>( {}, @@ -391,6 +396,7 @@ export function TableView({ useEffect(() => { if (!tableContext) { setPrimaryKeys([]); + setClickhouseEngine(null); setTableColumns([]); setColumnComments({}); return; @@ -405,6 +411,7 @@ export function TableView({ .then((meta) => { const pks = meta.columns.filter((c) => c.primaryKey).map((c) => c.name); setPrimaryKeys(pks); + setClickhouseEngine(meta.clickhouseExtra?.engine || null); setTableColumns(meta.columns); const comments: Record = {}; @@ -419,6 +426,7 @@ export function TableView({ .catch((e) => { console.error("Failed to fetch primary keys:", e); setPrimaryKeys([]); + setClickhouseEngine(null); setTableColumns([]); setColumnComments({}); }); @@ -446,10 +454,42 @@ export function TableView({ setSaveError(null); }, [data, page]); - const isReadOnlyDriver = tableContext?.driver === "clickhouse"; - const isEditable = - !!tableContext && !isReadOnlyDriver && primaryKeys.length > 0; - const isEditableForUpdates = isEditable && !hasLocalClientSort; + const isClickHouseDriver = tableContext?.driver === "clickhouse"; + const hasPrimaryKeys = primaryKeys.length > 0; + const canInsert = !!tableContext && + (isClickHouseDriver + ? isClickHouseMergeTreeEngine(clickhouseEngine) + : hasPrimaryKeys); + const canUpdateDelete = !!tableContext && + (isClickHouseDriver + ? 
canMutateClickHouseTable(clickhouseEngine, primaryKeys) + : hasPrimaryKeys); + const isEditableForUpdates = canUpdateDelete && !hasLocalClientSort; + const mutabilityHint = useMemo(() => { + if (!tableContext) return null; + if (hasLocalClientSort) { + return "Inline cell editing is disabled while client-side sorting is active."; + } + if (isClickHouseDriver) { + if (!isClickHouseMergeTreeEngine(clickhouseEngine)) { + return "ClickHouse inline write is only enabled for MergeTree-family tables."; + } + if (!hasPrimaryKeys) { + return "ClickHouse table update/delete requires primary key columns."; + } + return null; + } + if (!hasPrimaryKeys) { + return "This table has no primary key and does not support inline editing"; + } + return null; + }, [ + tableContext, + hasLocalClientSort, + isClickHouseDriver, + clickhouseEngine, + hasPrimaryKeys, + ]); const pendingMutationCount = pendingChanges.size + insertDraftRows.length; const hasPendingChanges = pendingMutationCount > 0; @@ -629,7 +669,7 @@ export function TableView({ // --- SQL generation & save --- const generateUpdateSQL = useCallback(() => { - if (!tableContext || primaryKeys.length === 0) return []; + if (!tableContext || !canUpdateDelete || primaryKeys.length === 0) return []; // Group changes by source row index const changesByRow = new Map(); @@ -671,15 +711,27 @@ export function TableView({ const tableName = getQualifiedTableName(driver, schema, table); - const sql = `UPDATE ${tableName} SET ${setClauses.join(", ")} WHERE ${whereClauses.join(" AND ")}`; + const sql = buildUpdateStatement( + driver, + tableName, + setClauses.join(", "), + whereClauses.join(" AND "), + ); sqls.push(sql); }); return sqls; - }, [tableContext, primaryKeys, pendingChanges, data, currentData]); + }, [ + tableContext, + canUpdateDelete, + primaryKeys, + pendingChanges, + data, + currentData, + ]); const generateInsertSQL = useCallback(() => { - if (!tableContext || !insertDraftRows.length) return []; + if (!tableContext || 
!canInsert || !insertDraftRows.length) return []; const tableName = getQualifiedTableName( tableContext.driver, tableContext.schema, @@ -716,7 +768,9 @@ export function TableView({ }); if (!insertColumns.length) { - throw new Error(`Row ${index + 1}: at least one column value is required`); + throw new Error( + `Row ${index + 1}: at least one column value is required`, + ); } sqls.push( @@ -725,10 +779,15 @@ export function TableView({ }); return sqls; - }, [tableContext, insertDraftRows, tableColumns, columns]); + }, [tableContext, canInsert, insertDraftRows, tableColumns, columns]); const buildDeleteSQL = useCallback(() => { - if (!tableContext || !selectedRows.size || primaryKeys.length === 0) { + if ( + !tableContext || + !canUpdateDelete || + !selectedRows.size || + primaryKeys.length === 0 + ) { return ""; } @@ -758,10 +817,15 @@ export function TableView({ tableContext.schema, tableContext.table, ); - return `DELETE FROM ${tableName} WHERE ${rowClauses.join(" OR ")}`; - }, [tableContext, selectedRows, primaryKeys, currentData]); + return buildDeleteStatement( + tableContext.driver, + tableName, + rowClauses.join(" OR "), + ); + }, [tableContext, canUpdateDelete, selectedRows, primaryKeys, currentData]); const handleAddDraftRow = useCallback(() => { + if (!canInsert) return; const tempId = `draft_${Date.now()}_${Math.random().toString(36).slice(2, 8)}`; const values = columns.reduce>((acc, column) => { acc[column] = ""; @@ -769,7 +833,7 @@ export function TableView({ }, {}); setInsertDraftRows((prev) => [...prev, { tempId, values }]); setPendingFocusDraftId(tempId); - }, [columns]); + }, [canInsert, columns]); const handleDraftValueChange = useCallback( (tempId: string, column: string, value: string) => { @@ -784,8 +848,26 @@ export function TableView({ [], ); + const refreshAfterMutation = useCallback(async () => { + if (!onDataRefresh) return; + const runRefresh = async () => { + const ret = onDataRefresh(); + if (ret && typeof (ret as Promise).then === 
"function") { + await ret; + } + }; + + await runRefresh(); + if (tableContext?.driver === "clickhouse") { + await new Promise((resolve) => setTimeout(resolve, 350)); + await runRefresh(); + } + }, [onDataRefresh, tableContext?.driver]); + const handleConfirmDelete = useCallback(async () => { - if (!tableContext || !selectedRows.size || isDeleting) return; + if (!tableContext || !canUpdateDelete || !selectedRows.size || isDeleting) { + return; + } const sql = buildDeleteSQL(); if (!sql) { @@ -803,10 +885,13 @@ export function TableView({ "table_view_save", ); setDeleteDialogOpen(false); - setSelectedRows(new Set()); + const nextSelectedRows = new Set(); + selectedRowsRef.current = nextSelectedRows; + setSelectedRows(nextSelectedRows); + selectedCellRef.current = null; setSelectedCell(null); setEditingCell(null); - onDataRefresh?.(); + await refreshAfterMutation(); } catch (e) { setSaveError( `Delete failed:\n${sql}\n -> ${e instanceof Error ? e.message : String(e)}`, @@ -814,7 +899,14 @@ export function TableView({ } finally { setIsDeleting(false); } - }, [tableContext, selectedRows, isDeleting, buildDeleteSQL, onDataRefresh]); + }, [ + tableContext, + canUpdateDelete, + selectedRows, + isDeleting, + buildDeleteSQL, + refreshAfterMutation, + ]); const handleSave = useCallback(async () => { if (!tableContext || !hasPendingChanges) return; @@ -864,14 +956,14 @@ export function TableView({ setPendingChanges(new Map()); setInsertDraftRows([]); setSaveError(null); - onDataRefresh?.(); + await refreshAfterMutation(); } }, [ tableContext, hasPendingChanges, generateUpdateSQL, generateInsertSQL, - onDataRefresh, + refreshAfterMutation, ]); const handleRefreshClick = useCallback(async () => { @@ -1055,7 +1147,7 @@ export function TableView({ const buildRowsUpdateSQL = useCallback( (rowIndexes: number[]) => { - if (!tableContext || primaryKeys.length === 0) return ""; + if (!tableContext || !canUpdateDelete || primaryKeys.length === 0) return ""; const orderedRows = 
[...rowIndexes].sort((a, b) => a - b); const { schema, table, driver } = tableContext; const tableName = getQualifiedTableName(driver, schema, table); @@ -1087,12 +1179,19 @@ export function TableView({ return `${quoteIdent(driver, pk)} = '${escapeSQL(String(pkValue))}'`; }); - return `UPDATE ${tableName} SET ${setClauses.join(", ")} WHERE ${whereClauses.join(" AND ")};`; + return `${buildUpdateStatement(driver, tableName, setClauses.join(", "), whereClauses.join(" AND "))};`; }) .filter((line) => line.length > 0) .join("\n"); }, - [columns, currentData, getCellDisplayValue, primaryKeys, tableContext], + [ + columns, + currentData, + getCellDisplayValue, + canUpdateDelete, + primaryKeys, + tableContext, + ], ); const normalizedSearchKeyword = searchKeyword.trim().toLowerCase(); @@ -1265,9 +1364,8 @@ export function TableView({ if (!pendingFocusDraftId) return; const selector = `input[data-draft-id="${pendingFocusDraftId}"][data-draft-col-index="0"]`; requestAnimationFrame(() => { - const target = containerRef.current?.querySelector( - selector, - ); + const target = + containerRef.current?.querySelector(selector); if (!target) return; target.scrollIntoView({ behavior: "smooth", @@ -1557,8 +1655,9 @@ export function TableView({ {t("connection.menu.newQuery")} )} - {isEditable && ( + {(canInsert || canUpdateDelete) && ( <> + {canInsert && ( + )} + {canUpdateDelete && ( + )} )} @@ -1752,38 +1854,22 @@ export function TableView({ }} /> - {tableContext && - (!isEditable || hasLocalClientSort) && - (primaryKeys.length === 0 || isReadOnlyDriver || hasLocalClientSort) && ( + {tableContext && mutabilityHint && ( - Read-only + {canInsert ? "Partial write" : "Read-only"} )} ) : ( - tableContext && - (!isEditable || hasLocalClientSort) && - (primaryKeys.length === 0 || isReadOnlyDriver || hasLocalClientSort) && ( + tableContext && mutabilityHint && ( - Read-only + {canInsert ? 
"Partial write" : "Read-only"} ) )} @@ -1942,7 +2028,10 @@ export function TableView({ }} onClick={() => handleCellClick(rowIndex, column)} onContextMenu={() => { - if (selectedRows.size > 1 && selectedRows.has(rowIndex)) { + if ( + selectedRows.size > 1 && + selectedRows.has(rowIndex) + ) { return; } handleCellClick(rowIndex, column); @@ -1969,7 +2058,7 @@ export function TableView({ ) : (
{displayValue !== null && - displayValue !== undefined ? ( + displayValue !== undefined ? ( - {isEditable && + {canUpdateDelete && isCellModified(rowIndex, selectedCell?.col || "") && ( <> )} - {isEditable && ( + {canUpdateDelete && ( { const sql = buildRowsUpdateSQL(copyTargetRows); diff --git a/src/components/business/DataGrid/tableView/utils.ts b/src/components/business/DataGrid/tableView/utils.ts index 5ee3075..04c4a3a 100644 --- a/src/components/business/DataGrid/tableView/utils.ts +++ b/src/components/business/DataGrid/tableView/utils.ts @@ -107,7 +107,11 @@ export function collectSearchMatches( currentData: any[], columns: string[], normalizedSearchKeyword: string, - getCellDisplayValue: (rowIndex: number, column: string, originalValue: any) => any, + getCellDisplayValue: ( + rowIndex: number, + column: string, + originalValue: any, + ) => any, ): SearchMatch[] { if (!normalizedSearchKeyword) { return []; @@ -259,3 +263,40 @@ export function getQualifiedTableName( return `${quoteIdent(driver, schema)}.${quoteIdent(driver, table)}`; } + +export function isClickHouseMergeTreeEngine( + engine: string | undefined | null, +): boolean { + if (!engine) return false; + return engine.toLowerCase().includes("mergetree"); +} + +export function canMutateClickHouseTable( + engine: string | undefined | null, + primaryKeys: string[], +): boolean { + return isClickHouseMergeTreeEngine(engine) && primaryKeys.length > 0; +} + +export function buildUpdateStatement( + driver: string, + tableName: string, + setClause: string, + whereClause: string, +): string { + if (driver === "clickhouse") { + return `ALTER TABLE ${tableName} UPDATE ${setClause} WHERE ${whereClause}`; + } + return `UPDATE ${tableName} SET ${setClause} WHERE ${whereClause}`; +} + +export function buildDeleteStatement( + driver: string, + tableName: string, + whereClause: string, +): string { + if (driver === "clickhouse") { + return `ALTER TABLE ${tableName} DELETE WHERE ${whereClause}`; + } + return `DELETE 
FROM ${tableName} WHERE ${whereClause}`; +} diff --git a/src/components/business/DataGrid/tableView/utils.unit.test.ts b/src/components/business/DataGrid/tableView/utils.unit.test.ts index 9e63303..ab32258 100644 --- a/src/components/business/DataGrid/tableView/utils.unit.test.ts +++ b/src/components/business/DataGrid/tableView/utils.unit.test.ts @@ -1,8 +1,12 @@ import { describe, expect, test } from "bun:test"; import { + buildDeleteStatement, + buildUpdateStatement, + canMutateClickHouseTable, formatInsertSQLValue, formatSQLValue, getQualifiedTableName, + isClickHouseMergeTreeEngine, isInsertColumnRequired, } from "./utils"; @@ -56,7 +60,11 @@ describe("formatInsertSQLValue", () => { test("formats boolean values", () => { expect( - formatInsertSQLValue("true", { name: "enabled", type: "boolean" }, "postgres"), + formatInsertSQLValue( + "true", + { name: "enabled", type: "boolean" }, + "postgres", + ), ).toBe("TRUE"); expect( formatInsertSQLValue("0", { name: "enabled", type: "boolean" }, "mssql"), @@ -65,7 +73,11 @@ describe("formatInsertSQLValue", () => { test("throws for invalid boolean values", () => { expect(() => - formatInsertSQLValue("yes", { name: "enabled", type: "boolean" }, "postgres"), + formatInsertSQLValue( + "yes", + { name: "enabled", type: "boolean" }, + "postgres", + ), ).toThrow('Invalid boolean value for column "enabled": "yes"'); }); @@ -84,9 +96,9 @@ describe("isInsertColumnRequired", () => { }); test("returns false when nullable", () => { - expect( - isInsertColumnRequired({ nullable: true, defaultValue: null }), - ).toBe(false); + expect(isInsertColumnRequired({ nullable: true, defaultValue: null })).toBe( + false, + ); }); test("returns false when default value exists", () => { @@ -101,7 +113,9 @@ describe("isInsertColumnRequired", () => { describe("getQualifiedTableName", () => { test("uses unqualified table with backticks for tidb", () => { - expect(getQualifiedTableName("tidb", "analytics", "events")).toBe("`events`"); + 
expect(getQualifiedTableName("tidb", "analytics", "events")).toBe( + "`events`", + ); }); test("uses unqualified table with backticks for mariadb", () => { @@ -111,20 +125,62 @@ describe("getQualifiedTableName", () => { }); test("does not qualify sqlite main/public schema", () => { - expect(getQualifiedTableName("sqlite", "main", "users")).toBe("\"users\""); - expect(getQualifiedTableName("sqlite", "public", "users")).toBe("\"users\""); - expect(getQualifiedTableName("sqlite", "", "users")).toBe("\"users\""); + expect(getQualifiedTableName("sqlite", "main", "users")).toBe('"users"'); + expect(getQualifiedTableName("sqlite", "public", "users")).toBe('"users"'); + expect(getQualifiedTableName("sqlite", "", "users")).toBe('"users"'); }); test("keeps non-main sqlite schema qualification", () => { expect(getQualifiedTableName("sqlite", "analytics", "events")).toBe( - "\"analytics\".\"events\"", + '"analytics"."events"', ); }); test("does not qualify duckdb main/public schema", () => { - expect(getQualifiedTableName("duckdb", "main", "users")).toBe("\"users\""); - expect(getQualifiedTableName("duckdb", "public", "users")).toBe("\"users\""); - expect(getQualifiedTableName("duckdb", "", "users")).toBe("\"users\""); + expect(getQualifiedTableName("duckdb", "main", "users")).toBe('"users"'); + expect(getQualifiedTableName("duckdb", "public", "users")).toBe('"users"'); + expect(getQualifiedTableName("duckdb", "", "users")).toBe('"users"'); + }); +}); + +describe("clickhouse mutation guards", () => { + test("detects mergetree engine variants", () => { + expect(isClickHouseMergeTreeEngine("MergeTree")).toBe(true); + expect(isClickHouseMergeTreeEngine("ReplacingMergeTree")).toBe(true); + expect(isClickHouseMergeTreeEngine("Memory")).toBe(false); + }); + + test("requires both mergetree engine and primary keys", () => { + expect(canMutateClickHouseTable("MergeTree", ["id"])).toBe(true); + expect(canMutateClickHouseTable("MergeTree", [])).toBe(false); + 
expect(canMutateClickHouseTable("Log", ["id"])).toBe(false); + }); +}); + +describe("mutation statement builders", () => { + test("builds clickhouse alter update/delete statements", () => { + expect( + buildUpdateStatement( + "clickhouse", + "`analytics`.`events`", + "`name` = 'new'", + "`id` = 1", + ), + ).toBe( + "ALTER TABLE `analytics`.`events` UPDATE `name` = 'new' WHERE `id` = 1", + ); + + expect( + buildDeleteStatement("clickhouse", "`analytics`.`events`", "`id` = 1"), + ).toBe("ALTER TABLE `analytics`.`events` DELETE WHERE `id` = 1"); + }); + + test("keeps generic update/delete statements for non-clickhouse", () => { + expect( + buildUpdateStatement("postgres", '"public"."users"', '"name" = \'new\'', '"id" = 1'), + ).toBe("UPDATE \"public\".\"users\" SET \"name\" = 'new' WHERE \"id\" = 1"); + expect(buildDeleteStatement("postgres", '"public"."users"', '"id" = 1')).toBe( + "DELETE FROM \"public\".\"users\" WHERE \"id\" = 1", + ); }); }); diff --git a/src/components/business/Editor/SqlEditor.tsx b/src/components/business/Editor/SqlEditor.tsx index 5eb67e1..b2a34fa 100644 --- a/src/components/business/Editor/SqlEditor.tsx +++ b/src/components/business/Editor/SqlEditor.tsx @@ -479,9 +479,12 @@ export function SqlEditor({ format, filePath, }); - toast.success(t("sqlEditor.export.completed", { count: result.rowCount }), { - description: result.filePath, - }); + toast.success( + t("sqlEditor.export.completed", { count: result.rowCount }), + { + description: result.filePath, + }, + ); } catch (e) { toast.error(t("sqlEditor.export.failed"), { description: e instanceof Error ? 
e.message : String(e), @@ -494,7 +497,10 @@ export function SqlEditor({ const triggerSave = useCallback(() => { const currentId = savedQueryIdRef.current; if (currentId) { - executeSave(initialName || t("sqlEditor.untitled"), initialDescription || ""); + executeSave( + initialName || t("sqlEditor.untitled"), + initialDescription || "", + ); } else { setIsSaveDialogOpen(true); } @@ -603,7 +609,10 @@ export function SqlEditor({ .filter((item): item is CompletionResult => !!item); if (!results.length) return null; - const from = results.reduce((min, curr) => Math.min(min, curr.from), results[0].from); + const from = results.reduce( + (min, curr) => Math.min(min, curr.from), + results[0].from, + ); const options: NonNullable[number][] = []; const seen = new Set(); for (const result of results) { @@ -691,16 +700,13 @@ export function SqlEditor({
- {databaseName && ( - canSwitchDatabase ? ( + {databaseName && + (canSwitchDatabase ? (
- )}
- ) - )} + ))}
diff --git a/src/components/business/Editor/codemirrorTheme.ts b/src/components/business/Editor/codemirrorTheme.ts index 792c346..c26e0ad 100644 --- a/src/components/business/Editor/codemirrorTheme.ts +++ b/src/components/business/Editor/codemirrorTheme.ts @@ -32,7 +32,7 @@ const baseThemeSpec: Parameters[0] = { }, ".cm-selectionBackground, &.cm-focused .cm-selectionBackground, .cm-content ::selection, &.cm-focused .cm-content ::selection": { - backgroundColor: "var(--editor-selection-bg) !important", + backgroundColor: "var(--editor-selection-bg) !important", }, ".cm-tooltip": { backgroundColor: "var(--popover)", diff --git a/src/components/business/Metadata/TableMetadataView.tsx b/src/components/business/Metadata/TableMetadataView.tsx index 7d56ca8..93aee78 100644 --- a/src/components/business/Metadata/TableMetadataView.tsx +++ b/src/components/business/Metadata/TableMetadataView.tsx @@ -207,7 +207,9 @@ export function TableMetadataView({
{clickhouseExtra.partitionKey && (
-
Partition Key
+
+ Partition Key +
{clickhouseExtra.partitionKey}
@@ -215,7 +217,9 @@ export function TableMetadataView({ )} {clickhouseExtra.sortingKey && (
-
Sorting Key
+
+ Sorting Key +
{clickhouseExtra.sortingKey}
@@ -233,7 +237,9 @@ export function TableMetadataView({ )} {clickhouseExtra.samplingKey && (
-
Sampling Key
+
+ Sampling Key +
{clickhouseExtra.samplingKey}
diff --git a/src/components/business/Sidebar/ConnectionList.tsx b/src/components/business/Sidebar/ConnectionList.tsx index 1f9f889..da6ffd9 100644 --- a/src/components/business/Sidebar/ConnectionList.tsx +++ b/src/components/business/Sidebar/ConnectionList.tsx @@ -16,6 +16,7 @@ import { Search, Download, FolderOpen, + Upload, } from "lucide-react"; import { Button } from "@/components/ui/button"; import { @@ -53,7 +54,7 @@ import { ContextMenuItem, ContextMenuTrigger, } from "@/components/ui/context-menu"; -import { api, isTauri } from "@/services/api"; +import { api, getImportDriverCapability, isTauri } from "@/services/api"; import type { ConnectionForm, CreateDatabasePayload, @@ -153,6 +154,7 @@ const createDatabaseSupportedDrivers: Driver[] = [ "mysql", "mariadb", "tidb", + "clickhouse", "mssql", ]; @@ -189,7 +191,6 @@ const mssqlCollationOptions = [ "Japanese_CI_AS", ]; const schemaNodeDrivers: Driver[] = ["postgres", "mssql"]; - interface ConnectionListProps { onTableSelect?: ( connection: string, @@ -257,7 +258,9 @@ export function ConnectionList({ const [expandedQueryGroups, setExpandedQueryGroups] = useState>( new Set(), ); - const [expandedSchemas, setExpandedSchemas] = useState>(new Set()); + const [expandedSchemas, setExpandedSchemas] = useState>( + new Set(), + ); const [expandedTables, setExpandedTables] = useState>(new Set()); const [selectedTableKey, setSelectedTableKey] = useState(null); const [autoScrollRequest, setAutoScrollRequest] = useState<{ @@ -283,17 +286,18 @@ export function ConnectionList({ const [isSavingEdit, setIsSavingEdit] = useState(false); const [isDeleting, setIsDeleting] = useState(false); const [isCreatingDatabase, setIsCreatingDatabase] = useState(false); + const [isImportingSql, setIsImportingSql] = useState(false); const [deleteTargetConnectionId, setDeleteTargetConnectionId] = useState< string | null >(null); - const [createDbConnectionId, setCreateDbConnectionId] = useState( - null, - ); + const [createDbConnectionId, 
setCreateDbConnectionId] = useState< + string | null + >(null); const [isCreateDbDialogOpen, setIsCreateDbDialogOpen] = useState(false); const [showCreateDbAdvanced, setShowCreateDbAdvanced] = useState(false); - const [createDbValidationMsg, setCreateDbValidationMsg] = useState( - null, - ); + const [createDbValidationMsg, setCreateDbValidationMsg] = useState< + string | null + >(null); const [createDbForm, setCreateDbForm] = useState( defaultCreateDatabaseForm, ); @@ -308,6 +312,13 @@ export function ConnectionList({ const [savedQueriesByConnection, setSavedQueriesByConnection] = useState< Record >({}); + const [pendingImport, setPendingImport] = useState<{ + connectionId: string; + databaseName: string; + driver: Driver; + filePath: string; + } | null>(null); + const [isImportConfirmOpen, setIsImportConfirmOpen] = useState(false); const supportsCreateDatabaseForDriver = (driver: Driver) => createDatabaseSupportedDrivers.includes(driver); @@ -425,7 +436,12 @@ export function ConnectionList({ return null; }) .filter(Boolean) as Connection[]; - }, [connections, savedQueriesByConnection, searchTerm, showSavedQueriesInTree]); + }, [ + connections, + savedQueriesByConnection, + searchTerm, + showSavedQueriesInTree, + ]); useEffect(() => { if (searchTerm) { @@ -474,11 +490,7 @@ export function ConnectionList({ }); } } - }, [ - searchTerm, - filteredConnections, - showSavedQueriesInTree, - ]); + }, [searchTerm, filteredConnections, showSavedQueriesInTree]); useEffect( () => () => { @@ -503,7 +515,10 @@ export function ConnectionList({ form.driver !== "mariadb" ); }, [form.driver]); - const normalizedForm = useMemo(() => normalizeConnectionFormInput(form), [form]); + const normalizedForm = useMemo( + () => normalizeConnectionFormInput(form), + [form], + ); const validationIssues = useMemo( () => validateConnectionFormInput( @@ -564,7 +579,10 @@ export function ConnectionList({ const selectedPath = await pickSingleFile({ title: t("connection.dialog.sslCaFileDialogTitle"), 
filters: [ - { name: t("connection.dialog.fileFilterCert"), extensions: ["pem", "crt", "cer"] }, + { + name: t("connection.dialog.fileFilterCert"), + extensions: ["pem", "crt", "cer"], + }, { name: t("connection.dialog.fileFilterAll"), extensions: ["*"] }, ], }); @@ -583,7 +601,10 @@ export function ConnectionList({ const selectedPath = await pickSingleFile({ title: t("connection.dialog.sshKeyFileDialogTitle"), filters: [ - { name: t("connection.dialog.fileFilterPem"), extensions: ["pem", "key", "ppk"] }, + { + name: t("connection.dialog.fileFilterPem"), + extensions: ["pem", "key", "ppk"], + }, { name: t("connection.dialog.fileFilterAll"), extensions: ["*"] }, ], }); @@ -664,7 +685,8 @@ export function ConnectionList({ const toggleConnection = (id: string) => { const connection = connections.find((conn) => conn.id === id); if (!connection) return; - if (connection.connectState !== "success" && !showSavedQueriesInTree) return; + if (connection.connectState !== "success" && !showSavedQueriesInTree) + return; const newExpanded = new Set(expandedConnections); if (newExpanded.has(id)) { @@ -1233,24 +1255,30 @@ export function ConnectionList({ conn.id === connectionId ? 
{ ...conn, databases: [] } : conn, ), ); - setExpandedDatabases((prev) => - new Set([...prev].filter((key) => !key.startsWith(`${connectionId}-`))), + setExpandedDatabases( + (prev) => + new Set([...prev].filter((key) => !key.startsWith(`${connectionId}-`))), ); - setExpandedSchemas((prev) => - new Set([...prev].filter((key) => !key.startsWith(`${connectionId}-`))), + setExpandedSchemas( + (prev) => + new Set([...prev].filter((key) => !key.startsWith(`${connectionId}-`))), ); - setExpandedTables((prev) => - new Set([...prev].filter((key) => !key.startsWith(`${connectionId}-`))), + setExpandedTables( + (prev) => + new Set([...prev].filter((key) => !key.startsWith(`${connectionId}-`))), ); }; const handleCreateDatabase = async () => { const connection = createDbTargetConnection; - if (!connection || !supportsCreateDatabaseForDriver(connection.type)) return; + if (!connection || !supportsCreateDatabaseForDriver(connection.type)) + return; const name = createDbForm.name.trim(); if (!name) { - setCreateDbValidationMsg(t("connection.createDbDialog.validation.requiredName")); + setCreateDbValidationMsg( + t("connection.createDbDialog.validation.requiredName"), + ); return; } @@ -1259,16 +1287,19 @@ export function ConnectionList({ ifNotExists: createDbForm.ifNotExists, }; if (isMySqlFamilyCreateDb) { - if (createDbForm.charset.trim()) payload.charset = createDbForm.charset.trim(); + if (createDbForm.charset.trim()) + payload.charset = createDbForm.charset.trim(); if (createDbForm.collation.trim()) { payload.collation = createDbForm.collation.trim(); } } else if (isPostgresCreateDb) { - if (createDbForm.encoding.trim()) payload.encoding = createDbForm.encoding.trim(); + if (createDbForm.encoding.trim()) + payload.encoding = createDbForm.encoding.trim(); if (createDbForm.lcCollate.trim()) { payload.lcCollate = createDbForm.lcCollate.trim(); } - if (createDbForm.lcCtype.trim()) payload.lcCtype = createDbForm.lcCtype.trim(); + if (createDbForm.lcCtype.trim()) + 
payload.lcCtype = createDbForm.lcCtype.trim(); } else if (isMssqlCreateDb) { if (createDbForm.collation.trim()) { payload.collation = createDbForm.collation.trim(); @@ -1614,9 +1645,92 @@ export function ConnectionList({ } }; + const handleDatabaseImport = async ( + connectionId: string, + databaseName: string, + ) => { + const connection = connections.find((conn) => conn.id === connectionId); + if (!connection) return; + + const capability = getImportDriverCapability(connection.type); + if (capability === "read_only_not_supported") { + toast.error(t("connection.toast.importReadOnlyDriver")); + return; + } + + if (capability !== "supported") { + toast.error(t("connection.toast.importUnsupportedDriver")); + return; + } + + if (!isTauri()) { + toast.error(t("connection.toast.importDesktopOnly")); + return; + } + + const selectedPath = await pickSingleFile({ + title: t("connection.toast.selectImportSqlFile"), + filters: [{ name: "SQL", extensions: ["sql"] }], + }); + if (!selectedPath) return; + + setPendingImport({ + connectionId, + databaseName, + driver: connection.type, + filePath: selectedPath, + }); + setIsImportConfirmOpen(true); + }; + + const handleConfirmImport = async () => { + if (!pendingImport) return; + + setIsImportingSql(true); + try { + const result = await api.transfer.importSqlFile({ + id: Number(pendingImport.connectionId), + database: pendingImport.databaseName, + filePath: pendingImport.filePath, + driver: pendingImport.driver, + }); + + if (result.error || result.failedAt) { + toast.error(t("connection.toast.importFailed"), { + description: result.error || t("common.unknown"), + }); + } else { + toast.success( + t("connection.toast.importSuccess", { + count: result.successStatements, + }), + { + description: pendingImport.filePath, + }, + ); + } + + await handleRefreshDatabaseTables( + pendingImport.connectionId, + pendingImport.databaseName, + ); + } catch (e) { + toast.error(t("connection.toast.importFailed"), { + description: e instanceof 
Error ? e.message : String(e), + }); + } finally { + setIsImportingSql(false); + setIsImportConfirmOpen(false); + setPendingImport(null); + } + }; + const contextMenuConnection = contextMenu.connectionId ? connections.find((conn) => conn.id === contextMenu.connectionId) : null; + const contextMenuDatabaseConnection = contextMenu.connectionId + ? connections.find((conn) => conn.id === contextMenu.connectionId) + : null; return (
@@ -1662,7 +1776,9 @@ export function ConnectionList({
- +
- + @@ -1777,16 +1899,19 @@ export function ConnectionList({
- +
- + ({ ...f, ssl: checked === true })) } /> - +
{form.ssl && supportsSslCa && (
@@ -1879,7 +2010,9 @@ export function ConnectionList({